ChatPaper.aiChatPaper

Die Mamba im Lama: Destillation und Beschleunigung hybrider Modelle

The Mamba in the Llama: Distilling and Accelerating Hybrid Models

August 27, 2024
Autoren: Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao
cs.AI

Zusammenfassung

Lineare RNN-Architekturen wie Mamba können im Bereich der Sprachmodellierung wettbewerbsfähig mit Transformer-Modellen sein und dabei vorteilhafte Bereitstellungseigenschaften aufweisen. Angesichts des Fokus auf das Training von groß angelegten Transformer-Modellen betrachten wir die Herausforderung, diese vorab trainierten Modelle für den Einsatz umzuwandeln. Wir zeigen, dass es machbar ist, große Transformer in lineare RNNs zu destillieren, indem wir die linearen Projektionsgewichte aus den Aufmerksamkeitsschichten mit akademischen GPU-Ressourcen wiederverwenden. Das resultierende Hybridmodell, das ein Viertel der Aufmerksamkeitsschichten integriert, erzielt eine vergleichbare Leistung wie der originale Transformer in Chat-Benchmarks und übertrifft Open-Source-Hybridmodelle von Mamba, die von Grund auf mit Billionen von Tokens trainiert wurden, sowohl in Chat-Benchmarks als auch in allgemeinen Benchmarks. Darüber hinaus stellen wir einen hardwarebewussten spekulativen Decodierungsalgorithmus vor, der die Inferenzgeschwindigkeit von Mamba und Hybridmodellen beschleunigt. Insgesamt zeigen wir, wie wir mit begrenzten Rechenressourcen viele der ursprünglichen Aufmerksamkeitsschichten entfernen und aus dem resultierenden Modell effizienter generieren können. Unser leistungsstärkstes Modell, destilliert aus Llama3-8B-Instruct, erzielt eine Gewinnrate von 29,61 bei Längensteuerung in AlpacaEval 2 gegenüber GPT-4 und 7,35 bei MT-Bench und übertrifft das beste anweisungsgesteuerte lineare RNN-Modell.
English
Linear RNN architectures, like Mamba, can be competitive with Transformer models in language modeling while having advantageous deployment characteristics. Given the focus on training large-scale Transformer models, we consider the challenge of converting these pretrained models for deployment. We demonstrate that it is feasible to distill large Transformers into linear RNNs by reusing the linear projection weights from attention layers with academic GPU resources. The resulting hybrid model, which incorporates a quarter of the attention layers, achieves performance comparable to the original Transformer in chat benchmarks and outperforms open-source hybrid Mamba models trained from scratch with trillions of tokens in both chat benchmarks and general benchmarks. Moreover, we introduce a hardware-aware speculative decoding algorithm that accelerates the inference speed of Mamba and hybrid models. Overall we show how, with limited computation resources, we can remove many of the original attention layers and generate from the resulting model more efficiently. Our top-performing model, distilled from Llama3-8B-Instruct, achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and 7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.

Summary

AI-Generated Summary

PDF426November 16, 2024