A Mamba na Lhama: Destilação e Aceleração de Modelos Híbridos
The Mamba in the Llama: Distilling and Accelerating Hybrid Models
August 27, 2024
Autores: Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao
cs.AI
Resumo
Arquiteturas lineares de RNN, como Mamba, podem ser competitivas com modelos Transformer na modelagem de linguagem, ao mesmo tempo em que possuem características vantajosas para implantação. Dado o foco no treinamento de modelos Transformer em larga escala, consideramos o desafio de converter esses modelos pré-treinados para implantação. Demonstramos que é viável destilar grandes Transformers em RNNs lineares reutilizando os pesos de projeção linear das camadas de atenção com recursos acadêmicos de GPU. O modelo híbrido resultante, que incorpora um quarto das camadas de atenção, alcança desempenho comparável ao Transformer original em benchmarks de chat e supera modelos híbridos Mamba de código aberto treinados do zero com trilhões de tokens, tanto em benchmarks de chat quanto em benchmarks gerais. Além disso, introduzimos um algoritmo de decodificação especulativa consciente do hardware que acelera a velocidade de inferência de modelos Mamba e híbridos. No geral, mostramos como, com recursos computacionais limitados, podemos remover muitas das camadas de atenção originais e gerar a partir do modelo resultante de forma mais eficiente. Nosso modelo de melhor desempenho, destilado do Llama3-8B-Instruct, alcança uma taxa de vitória controlada por comprimento de 29,61 no AlpacaEval 2 contra o GPT-4 e 7,35 no MT-Bench, superando o melhor modelo de RNN linear ajustado para instruções.
English
Linear RNN architectures, like Mamba, can be competitive with Transformer
models in language modeling while having advantageous deployment
characteristics. Given the focus on training large-scale Transformer models, we
consider the challenge of converting these pretrained models for deployment. We
demonstrate that it is feasible to distill large Transformers into linear RNNs
by reusing the linear projection weights from attention layers with academic
GPU resources. The resulting hybrid model, which incorporates a quarter of the
attention layers, achieves performance comparable to the original Transformer
in chat benchmarks and outperforms open-source hybrid Mamba models trained from
scratch with trillions of tokens in both chat benchmarks and general
benchmarks. Moreover, we introduce a hardware-aware speculative decoding
algorithm that accelerates the inference speed of Mamba and hybrid models.
Overall we show how, with limited computation resources, we can remove many of
the original attention layers and generate from the resulting model more
efficiently. Our top-performing model, distilled from Llama3-8B-Instruct,
achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and
7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.Summary
AI-Generated Summary