De Mamba in de Llama: Destilleren en Versnellen van Hybride Modellen
The Mamba in the Llama: Distilling and Accelerating Hybrid Models
August 27, 2024
Auteurs: Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao
cs.AI
Samenvatting
Lineaire RNN-architecturen, zoals Mamba, kunnen concurreren met Transformer-modellen in taalmodelvorming, terwijl ze gunstige implementatiekenmerken hebben. Gezien de focus op het trainen van grootschalige Transformer-modellen, beschouwen we de uitdaging om deze voorgetrainde modellen om te zetten voor implementatie. We tonen aan dat het haalbaar is om grote Transformers te destilleren naar lineaire RNN's door de lineaire projectiegewichten uit aandachtslagen te hergebruiken met academische GPU-bronnen. Het resulterende hybride model, dat een kwart van de aandachtslagen bevat, bereikt prestaties die vergelijkbaar zijn met de oorspronkelijke Transformer in chatbenchmarks en overtreft open-source hybride Mamba-modellen die vanaf nul zijn getraind met biljoenen tokens, zowel in chatbenchmarks als in algemene benchmarks. Bovendien introduceren we een hardwarebewust speculatief decodeeralgoritme dat de inferentiesnelheid van Mamba en hybride modellen versnelt. Over het geheel genomen laten we zien hoe we, met beperkte rekenbronnen, veel van de oorspronkelijke aandachtslagen kunnen verwijderen en efficiënter kunnen genereren uit het resulterende model. Ons best presterende model, gedestilleerd uit Llama3-8B-Instruct, behaalt een lengtegecontroleerde winratio van 29,61 op AlpacaEval 2 tegenover GPT-4 en 7,35 op MT-Bench, waarmee het het best presterende instructiegetrainde lineaire RNN-model overtreft.
English
Linear RNN architectures, like Mamba, can be competitive with Transformer
models in language modeling while having advantageous deployment
characteristics. Given the focus on training large-scale Transformer models, we
consider the challenge of converting these pretrained models for deployment. We
demonstrate that it is feasible to distill large Transformers into linear RNNs
by reusing the linear projection weights from attention layers with academic
GPU resources. The resulting hybrid model, which incorporates a quarter of the
attention layers, achieves performance comparable to the original Transformer
in chat benchmarks and outperforms open-source hybrid Mamba models trained from
scratch with trillions of tokens in both chat benchmarks and general
benchmarks. Moreover, we introduce a hardware-aware speculative decoding
algorithm that accelerates the inference speed of Mamba and hybrid models.
Overall we show how, with limited computation resources, we can remove many of
the original attention layers and generate from the resulting model more
efficiently. Our top-performing model, distilled from Llama3-8B-Instruct,
achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and
7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.Summary
AI-Generated Summary