ChatPaper.aiChatPaper

Il Mamba nel Llama: Distillazione e Accelerazione dei Modelli Ibridi

The Mamba in the Llama: Distilling and Accelerating Hybrid Models

August 27, 2024
Autori: Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao
cs.AI

Abstract

Le architetture RNN lineari, come Mamba, possono essere competitive con i modelli Transformer nel language modeling pur avendo caratteristiche di implementazione vantaggiose. Data l'attenzione posta all'addestramento di modelli Transformer su larga scala, consideriamo la sfida della conversione di tali modelli preaddestrati per l'implementazione. Dimostriamo che è fattibile distillare grandi Transformer in RNN lineari riutilizzando i pesi di proiezione lineare dai livelli di attenzione con risorse GPU accademiche. Il modello ibrido risultante, che incorpora un quarto dei livelli di attenzione, raggiunge prestazioni paragonabili all'originale Transformer nei benchmark di chat e supera i modelli ibridi Mamba open-source addestrati da zero con trilioni di token sia nei benchmark di chat che in quelli generali. Inoltre, introduciamo un algoritmo di decodifica speculativa consapevole dell'hardware che accelera la velocità di inferenza di Mamba e dei modelli ibridi. Nel complesso mostriamo come, con risorse di calcolo limitate, possiamo rimuovere molti dei livelli di attenzione originali e generare in modo più efficiente dal modello risultante. Il nostro modello di punta, distillato da Llama3-8B-Instruct, raggiunge un tasso di vittoria controllato dalla lunghezza del 29,61 su AlpacaEval 2 contro GPT-4 e del 7,35 su MT-Bench, superando il miglior modello RNN lineare ottimizzato per le istruzioni.
English
Linear RNN architectures, like Mamba, can be competitive with Transformer models in language modeling while having advantageous deployment characteristics. Given the focus on training large-scale Transformer models, we consider the challenge of converting these pretrained models for deployment. We demonstrate that it is feasible to distill large Transformers into linear RNNs by reusing the linear projection weights from attention layers with academic GPU resources. The resulting hybrid model, which incorporates a quarter of the attention layers, achieves performance comparable to the original Transformer in chat benchmarks and outperforms open-source hybrid Mamba models trained from scratch with trillions of tokens in both chat benchmarks and general benchmarks. Moreover, we introduce a hardware-aware speculative decoding algorithm that accelerates the inference speed of Mamba and hybrid models. Overall we show how, with limited computation resources, we can remove many of the original attention layers and generate from the resulting model more efficiently. Our top-performing model, distilled from Llama3-8B-Instruct, achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and 7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.
PDF426November 16, 2024