ラマの中のマンバ:ハイブリッドモデルの蒸留と加速化
The Mamba in the Llama: Distilling and Accelerating Hybrid Models
August 27, 2024
著者: Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao
cs.AI
要旨
MambaなどのLinear RNNアーキテクチャは、有利な展開特性を持ちながら、言語モデリングにおいてTransformerモデルと競争力を持つことができます。大規模Transformerモデルのトレーニングに焦点を当てる中で、これらの事前学習済みモデルを展開用に変換する課題を考慮します。我々は、アカデミックGPUリソースを使用して、アテンション層からの線形射影重みを再利用することで、大規模なTransformerをLinear RNNに蒸留することが可能であることを示します。アテンション層の四分の一を組み込んだ結果のハイブリッドモデルは、チャットベンチマークにおいて元のTransformerと比較可能な性能を達成し、オープンソースのハイブリッドMambaモデルをトレーニング済みのトリリオンのトークンよりも優れた性能を示します。さらに、Mambaおよびハイブリッドモデルの推論速度を加速するハードウェアに適した推測デコーディングアルゴリズムを導入します。総じて、限られた計算リソースで、多くの元のアテンション層を削除し、より効率的にモデルを生成できることを示します。Llama3-8B-Instructから蒸留された最高性能モデルは、AlpacaEval 2においてGPT-4に対して29.61の長さ制御された勝率を達成し、MT-Benchでは7.35を記録し、最高の命令に調整された線形RNNモデルを上回ります。
English
Linear RNN architectures, like Mamba, can be competitive with Transformer
models in language modeling while having advantageous deployment
characteristics. Given the focus on training large-scale Transformer models, we
consider the challenge of converting these pretrained models for deployment. We
demonstrate that it is feasible to distill large Transformers into linear RNNs
by reusing the linear projection weights from attention layers with academic
GPU resources. The resulting hybrid model, which incorporates a quarter of the
attention layers, achieves performance comparable to the original Transformer
in chat benchmarks and outperforms open-source hybrid Mamba models trained from
scratch with trillions of tokens in both chat benchmarks and general
benchmarks. Moreover, we introduce a hardware-aware speculative decoding
algorithm that accelerates the inference speed of Mamba and hybrid models.
Overall we show how, with limited computation resources, we can remove many of
the original attention layers and generate from the resulting model more
efficiently. Our top-performing model, distilled from Llama3-8B-Instruct,
achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and
7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN model.Summary
AI-Generated Summary