ChatPaper.aiChatPaper

La Mamba en la Llama: Destilación y Aceleración de Modelos Híbridos

The Mamba in the Llama: Distilling and Accelerating Hybrid Models

August 27, 2024
Autores: Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao
cs.AI

Resumen

Las arquitecturas lineales de RNN, como Mamba, pueden ser competitivas con los modelos Transformer en modelado de lenguaje, al mismo tiempo que presentan características de implementación ventajosas. Dado el enfoque en el entrenamiento de modelos Transformer a gran escala, consideramos el desafío de convertir estos modelos preentrenados para su implementación. Demostramos que es factible destilar grandes Transformers en RNN lineales reutilizando los pesos de proyección lineal de las capas de atención con recursos académicos de GPU. El modelo híbrido resultante, que incorpora un cuarto de las capas de atención, logra un rendimiento comparable al Transformer original en pruebas de chat y supera a los modelos híbridos Mamba de código abierto entrenados desde cero con billones de tokens tanto en pruebas de chat como en pruebas generales. Además, presentamos un algoritmo de decodificación especulativa consciente del hardware que acelera la velocidad de inferencia de los modelos Mamba y híbridos. En general, mostramos cómo, con recursos computacionales limitados, podemos eliminar muchas de las capas de atención originales y generar a partir del modelo resultante de manera más eficiente. Nuestro modelo de mejor rendimiento, destilado de Llama3-8B-Instruct, logra una tasa de victoria controlada por longitud del 29.61 en AlpacaEval 2 contra GPT-4 y 7.35 en MT-Bench, superando al mejor modelo de RNN lineal ajustado a instrucciones.
English
Linear RNN architectures, like Mamba, can be competitive with Transformer models in language modeling while having advantageous deployment characteristics. Given the focus on training large-scale Transformer models, we consider the challenge of converting these pretrained models for deployment. We demonstrate that it is feasible to distill large Transformers into linear RNNs by reusing the linear projection weights from attention layers with academic GPU resources. The resulting hybrid model, which incorporates a quarter of the attention layers, achieves performance comparable to the original Transformer in chat benchmarks and outperforms open-source hybrid Mamba models trained from scratch with trillions of tokens in both chat benchmarks and general benchmarks. Moreover, we introduce a hardware-aware speculative decoding algorithm that accelerates the inference speed of Mamba and hybrid models. Overall we show how, with limited computation resources, we can remove many of the original attention layers and generate from the resulting model more efficiently. Our top-performing model, distilled from Llama3-8B-Instruct, achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and 7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.

Summary

AI-Generated Summary

PDF426November 16, 2024