ChatPaper.aiChatPaper

羚羊中的眼镜蛇:蒸馏和加速混合模型

The Mamba in the Llama: Distilling and Accelerating Hybrid Models

August 27, 2024
作者: Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao
cs.AI

摘要

线性RNN架构,如Mamba,在语言建模方面可以与Transformer模型竞争,同时具有有利的部署特性。鉴于目前对训练大规模Transformer模型的关注,我们考虑将这些预训练模型转换为部署模型的挑战。我们展示了通过重复使用来自注意力层的线性投影权重,将大型Transformer蒸馏为线性RNN是可行的,使用学术GPU资源。由此产生的混合模型,其中包含四分之一的注意力层,实现了与原始Transformer在聊天基准测试中可比的性能,并且在聊天基准测试和通用基准测试中胜过从头开始训练的开源混合Mamba模型,后者使用了数万亿个标记。此外,我们引入了一种硬件感知的推测解码算法,加速了Mamba和混合模型的推理速度。总体而言,我们展示了如何在有限的计算资源下,可以去除许多原始注意力层,并更高效地生成从结果模型。我们的表现最佳模型,从Llama3-8B-Instruct蒸馏而来,在AlpacaEval 2上实现了29.61的长度控制胜率,超过了GPT-4,以及在MT-Bench上的7.35分,超越了最佳指令调整的线性RNN模型。
English
Linear RNN architectures, like Mamba, can be competitive with Transformer models in language modeling while having advantageous deployment characteristics. Given the focus on training large-scale Transformer models, we consider the challenge of converting these pretrained models for deployment. We demonstrate that it is feasible to distill large Transformers into linear RNNs by reusing the linear projection weights from attention layers with academic GPU resources. The resulting hybrid model, which incorporates a quarter of the attention layers, achieves performance comparable to the original Transformer in chat benchmarks and outperforms open-source hybrid Mamba models trained from scratch with trillions of tokens in both chat benchmarks and general benchmarks. Moreover, we introduce a hardware-aware speculative decoding algorithm that accelerates the inference speed of Mamba and hybrid models. Overall we show how, with limited computation resources, we can remove many of the original attention layers and generate from the resulting model more efficiently. Our top-performing model, distilled from Llama3-8B-Instruct, achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and 7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.

Summary

AI-Generated Summary

PDF426November 16, 2024