ChatPaper.aiChatPaper

Le Mamba dans le Llama : Distillation et Accélération des Modèles Hybrides

The Mamba in the Llama: Distilling and Accelerating Hybrid Models

August 27, 2024
Auteurs: Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao
cs.AI

Résumé

Les architectures RNN linéaires, telles que Mamba, peuvent être compétitives avec les modèles Transformer en modélisation de langage tout en présentant des caractéristiques de déploiement avantageuses. Étant donné l'accent mis sur l'entraînement de grands modèles Transformer, nous examinons le défi de convertir ces modèles pré-entraînés pour le déploiement. Nous démontrons qu'il est possible de distiller de grands Transformers en RNN linéaires en réutilisant les poids de projection linéaire des couches d'attention avec des ressources GPU académiques. Le modèle hybride résultant, qui intègre un quart des couches d'attention, atteint des performances comparables à celles du Transformer d'origine dans les benchmarks de chat et surpasse les modèles hybrides Mamba open-source entraînés à partir de zéro avec des billions de jetons, à la fois dans les benchmarks de chat et généraux. De plus, nous introduisons un algorithme de décodage spéculatif conscient du matériel qui accélère la vitesse d'inférence des modèles Mamba et hybrides. Dans l'ensemble, nous montrons comment, avec des ressources de calcul limitées, nous pouvons supprimer bon nombre des couches d'attention d'origine et générer à partir du modèle résultant de manière plus efficace. Notre modèle le plus performant, distillé à partir de Llama3-8B-Instruct, atteint un taux de victoire contrôlé par la longueur de 29,61 sur AlpacaEval 2 contre GPT-4 et 7,35 sur MT-Bench, surpassant le meilleur modèle RNN linéaire ajusté aux instructions.
English
Linear RNN architectures, like Mamba, can be competitive with Transformer models in language modeling while having advantageous deployment characteristics. Given the focus on training large-scale Transformer models, we consider the challenge of converting these pretrained models for deployment. We demonstrate that it is feasible to distill large Transformers into linear RNNs by reusing the linear projection weights from attention layers with academic GPU resources. The resulting hybrid model, which incorporates a quarter of the attention layers, achieves performance comparable to the original Transformer in chat benchmarks and outperforms open-source hybrid Mamba models trained from scratch with trillions of tokens in both chat benchmarks and general benchmarks. Moreover, we introduce a hardware-aware speculative decoding algorithm that accelerates the inference speed of Mamba and hybrid models. Overall we show how, with limited computation resources, we can remove many of the original attention layers and generate from the resulting model more efficiently. Our top-performing model, distilled from Llama3-8B-Instruct, achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and 7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.

Summary

AI-Generated Summary

PDF426November 16, 2024