Il Mamba nel Llama: Distillazione e Accelerazione dei Modelli Ibridi
The Mamba in the Llama: Distilling and Accelerating Hybrid Models
August 27, 2024
Autori: Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao
cs.AI
Abstract
Le architetture RNN lineari, come Mamba, possono essere competitive con i modelli Transformer nel language modeling pur avendo caratteristiche di implementazione vantaggiose. Data l'attenzione posta all'addestramento di modelli Transformer su larga scala, consideriamo la sfida della conversione di tali modelli preaddestrati per l'implementazione. Dimostriamo che è fattibile distillare grandi Transformer in RNN lineari riutilizzando i pesi di proiezione lineare dai livelli di attenzione con risorse GPU accademiche. Il modello ibrido risultante, che incorpora un quarto dei livelli di attenzione, raggiunge prestazioni paragonabili all'originale Transformer nei benchmark di chat e supera i modelli ibridi Mamba open-source addestrati da zero con trilioni di token sia nei benchmark di chat che in quelli generali. Inoltre, introduciamo un algoritmo di decodifica speculativa consapevole dell'hardware che accelera la velocità di inferenza di Mamba e dei modelli ibridi. Nel complesso mostriamo come, con risorse di calcolo limitate, possiamo rimuovere molti dei livelli di attenzione originali e generare in modo più efficiente dal modello risultante. Il nostro modello di punta, distillato da Llama3-8B-Instruct, raggiunge un tasso di vittoria controllato dalla lunghezza del 29,61 su AlpacaEval 2 contro GPT-4 e del 7,35 su MT-Bench, superando il miglior modello RNN lineare ottimizzato per le istruzioni.
English
Linear RNN architectures, like Mamba, can be competitive with Transformer
models in language modeling while having advantageous deployment
characteristics. Given the focus on training large-scale Transformer models, we
consider the challenge of converting these pretrained models for deployment. We
demonstrate that it is feasible to distill large Transformers into linear RNNs
by reusing the linear projection weights from attention layers with academic
GPU resources. The resulting hybrid model, which incorporates a quarter of the
attention layers, achieves performance comparable to the original Transformer
in chat benchmarks and outperforms open-source hybrid Mamba models trained from
scratch with trillions of tokens in both chat benchmarks and general
benchmarks. Moreover, we introduce a hardware-aware speculative decoding
algorithm that accelerates the inference speed of Mamba and hybrid models.
Overall we show how, with limited computation resources, we can remove many of
the original attention layers and generate from the resulting model more
efficiently. Our top-performing model, distilled from Llama3-8B-Instruct,
achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and
7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.