ChatPaper.aiChatPaper

A Mamba na Lhama: Destilação e Aceleração de Modelos Híbridos

The Mamba in the Llama: Distilling and Accelerating Hybrid Models

August 27, 2024
Autores: Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao
cs.AI

Resumo

Arquiteturas lineares de RNN, como Mamba, podem ser competitivas com modelos Transformer na modelagem de linguagem, ao mesmo tempo em que possuem características vantajosas para implantação. Dado o foco no treinamento de modelos Transformer em larga escala, consideramos o desafio de converter esses modelos pré-treinados para implantação. Demonstramos que é viável destilar grandes Transformers em RNNs lineares reutilizando os pesos de projeção linear das camadas de atenção com recursos acadêmicos de GPU. O modelo híbrido resultante, que incorpora um quarto das camadas de atenção, alcança desempenho comparável ao Transformer original em benchmarks de chat e supera modelos híbridos Mamba de código aberto treinados do zero com trilhões de tokens, tanto em benchmarks de chat quanto em benchmarks gerais. Além disso, introduzimos um algoritmo de decodificação especulativa consciente do hardware que acelera a velocidade de inferência de modelos Mamba e híbridos. No geral, mostramos como, com recursos computacionais limitados, podemos remover muitas das camadas de atenção originais e gerar a partir do modelo resultante de forma mais eficiente. Nosso modelo de melhor desempenho, destilado do Llama3-8B-Instruct, alcança uma taxa de vitória controlada por comprimento de 29,61 no AlpacaEval 2 contra o GPT-4 e 7,35 no MT-Bench, superando o melhor modelo de RNN linear ajustado para instruções.
English
Linear RNN architectures, like Mamba, can be competitive with Transformer models in language modeling while having advantageous deployment characteristics. Given the focus on training large-scale Transformer models, we consider the challenge of converting these pretrained models for deployment. We demonstrate that it is feasible to distill large Transformers into linear RNNs by reusing the linear projection weights from attention layers with academic GPU resources. The resulting hybrid model, which incorporates a quarter of the attention layers, achieves performance comparable to the original Transformer in chat benchmarks and outperforms open-source hybrid Mamba models trained from scratch with trillions of tokens in both chat benchmarks and general benchmarks. Moreover, we introduce a hardware-aware speculative decoding algorithm that accelerates the inference speed of Mamba and hybrid models. Overall we show how, with limited computation resources, we can remove many of the original attention layers and generate from the resulting model more efficiently. Our top-performing model, distilled from Llama3-8B-Instruct, achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and 7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.

Summary

AI-Generated Summary

PDF426November 16, 2024