Die Mamba im Lama: Destillation und Beschleunigung hybrider Modelle
The Mamba in the Llama: Distilling and Accelerating Hybrid Models
August 27, 2024
Autoren: Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao
cs.AI
Zusammenfassung
Lineare RNN-Architekturen wie Mamba können im Bereich der Sprachmodellierung wettbewerbsfähig mit Transformer-Modellen sein und dabei vorteilhafte Bereitstellungseigenschaften aufweisen. Angesichts des Fokus auf das Training von groß angelegten Transformer-Modellen betrachten wir die Herausforderung, diese vorab trainierten Modelle für den Einsatz umzuwandeln. Wir zeigen, dass es machbar ist, große Transformer in lineare RNNs zu destillieren, indem wir die linearen Projektionsgewichte aus den Aufmerksamkeitsschichten mit akademischen GPU-Ressourcen wiederverwenden. Das resultierende Hybridmodell, das ein Viertel der Aufmerksamkeitsschichten integriert, erzielt eine vergleichbare Leistung wie der originale Transformer in Chat-Benchmarks und übertrifft Open-Source-Hybridmodelle von Mamba, die von Grund auf mit Billionen von Tokens trainiert wurden, sowohl in Chat-Benchmarks als auch in allgemeinen Benchmarks. Darüber hinaus stellen wir einen hardwarebewussten spekulativen Decodierungsalgorithmus vor, der die Inferenzgeschwindigkeit von Mamba und Hybridmodellen beschleunigt. Insgesamt zeigen wir, wie wir mit begrenzten Rechenressourcen viele der ursprünglichen Aufmerksamkeitsschichten entfernen und aus dem resultierenden Modell effizienter generieren können. Unser leistungsstärkstes Modell, destilliert aus Llama3-8B-Instruct, erzielt eine Gewinnrate von 29,61 bei Längensteuerung in AlpacaEval 2 gegenüber GPT-4 und 7,35 bei MT-Bench und übertrifft das beste anweisungsgesteuerte lineare RNN-Modell.
English
Linear RNN architectures, like Mamba, can be competitive with Transformer
models in language modeling while having advantageous deployment
characteristics. Given the focus on training large-scale Transformer models, we
consider the challenge of converting these pretrained models for deployment. We
demonstrate that it is feasible to distill large Transformers into linear RNNs
by reusing the linear projection weights from attention layers with academic
GPU resources. The resulting hybrid model, which incorporates a quarter of the
attention layers, achieves performance comparable to the original Transformer
in chat benchmarks and outperforms open-source hybrid Mamba models trained from
scratch with trillions of tokens in both chat benchmarks and general
benchmarks. Moreover, we introduce a hardware-aware speculative decoding
algorithm that accelerates the inference speed of Mamba and hybrid models.
Overall we show how, with limited computation resources, we can remove many of
the original attention layers and generate from the resulting model more
efficiently. Our top-performing model, distilled from Llama3-8B-Instruct,
achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and
7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.Summary
AI-Generated Summary