ChatPaper.aiChatPaper

De Mamba in de Llama: Destilleren en Versnellen van Hybride Modellen

The Mamba in the Llama: Distilling and Accelerating Hybrid Models

August 27, 2024
Auteurs: Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao
cs.AI

Samenvatting

Lineaire RNN-architecturen, zoals Mamba, kunnen concurreren met Transformer-modellen in taalmodelvorming, terwijl ze gunstige implementatiekenmerken hebben. Gezien de focus op het trainen van grootschalige Transformer-modellen, beschouwen we de uitdaging om deze voorgetrainde modellen om te zetten voor implementatie. We tonen aan dat het haalbaar is om grote Transformers te destilleren naar lineaire RNN's door de lineaire projectiegewichten uit aandachtslagen te hergebruiken met academische GPU-bronnen. Het resulterende hybride model, dat een kwart van de aandachtslagen bevat, bereikt prestaties die vergelijkbaar zijn met de oorspronkelijke Transformer in chatbenchmarks en overtreft open-source hybride Mamba-modellen die vanaf nul zijn getraind met biljoenen tokens, zowel in chatbenchmarks als in algemene benchmarks. Bovendien introduceren we een hardwarebewust speculatief decodeeralgoritme dat de inferentiesnelheid van Mamba en hybride modellen versnelt. Over het geheel genomen laten we zien hoe we, met beperkte rekenbronnen, veel van de oorspronkelijke aandachtslagen kunnen verwijderen en efficiënter kunnen genereren uit het resulterende model. Ons best presterende model, gedestilleerd uit Llama3-8B-Instruct, behaalt een lengtegecontroleerde winratio van 29,61 op AlpacaEval 2 tegenover GPT-4 en 7,35 op MT-Bench, waarmee het het best presterende instructiegetrainde lineaire RNN-model overtreft.
English
Linear RNN architectures, like Mamba, can be competitive with Transformer models in language modeling while having advantageous deployment characteristics. Given the focus on training large-scale Transformer models, we consider the challenge of converting these pretrained models for deployment. We demonstrate that it is feasible to distill large Transformers into linear RNNs by reusing the linear projection weights from attention layers with academic GPU resources. The resulting hybrid model, which incorporates a quarter of the attention layers, achieves performance comparable to the original Transformer in chat benchmarks and outperforms open-source hybrid Mamba models trained from scratch with trillions of tokens in both chat benchmarks and general benchmarks. Moreover, we introduce a hardware-aware speculative decoding algorithm that accelerates the inference speed of Mamba and hybrid models. Overall we show how, with limited computation resources, we can remove many of the original attention layers and generate from the resulting model more efficiently. Our top-performing model, distilled from Llama3-8B-Instruct, achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and 7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.

Summary

AI-Generated Summary

PDF426November 16, 2024