シンプルな線形アテンション言語モデルは、リコールとスループットのトレードオフをバランスさせる。
Simple linear attention language models balance the recall-throughput tradeoff
February 28, 2024
著者: Simran Arora, Sabri Eyuboglu, Michael Zhang, Aman Timalsina, Silas Alberti, Dylan Zinsley, James Zou, Atri Rudra, Christopher Ré
cs.AI
要旨
最近の研究では、アテンションベースの言語モデルがリコール、すなわちコンテキスト内で以前に見たトークンを生成に反映する能力において優れていることが示されています。しかし、アテンションベースモデルの効率性は、推論時にKVキャッシュのメモリ消費が急激に増加することによってボトルネックとなっています。本研究では、リコールを損なうことなく言語モデルの効率性(例えば、メモリ消費の削減)を向上できるかどうかを探ります。幅広いアーキテクチャに対して実験と理論を適用することで、モデルの状態サイズとリコール能力の間に重要なトレードオフがあることを明らかにします。アテンションの効率的な代替手法(例えば、H3、Mamba、RWKV)は固定サイズのリカレント状態を維持しますが、リコールにおいて苦戦することが分かります。我々は、線形アテンションとスライディングウィンドウアテンションを組み合わせたシンプルなアーキテクチャであるBASEDを提案します。BASEDのウィンドウサイズと線形アテンションの特徴次元を変化させることで、状態サイズを調整し、リコールとメモリのトレードオフ曲線のパレートフロンティアを探索できます。一方の端ではアテンションの完全な品質を回復し、もう一方の端ではアテンション代替手法の小さな状態サイズを実現します。1.3bパラメータまでの言語モデルを学習し、BASEDが最も強力なサブクアドラティックモデル(例えば、Mamba)とパープレキシティにおいて同等であり、現実世界のリコール集約型タスクでは6.22ポイントの精度で優れていることを示します。線形アテンションの実装は、最適化された標準アテンション実装よりも効率が低いことが多いです。BASEDを競争力のあるものにするために、1.3bパラメータモデルを使用して1024トークンを生成する際に、FlashAttention-2よりも24倍高いスループットを実現するIOを意識したアルゴリズムを開発しました。本研究のコードは以下で提供されています:https://github.com/HazyResearch/based。
English
Recent work has shown that attention-based language models excel at recall,
the ability to ground generations in tokens previously seen in context.
However, the efficiency of attention-based models is bottle-necked during
inference by the KV-cache's aggressive memory consumption. In this work, we
explore whether we can improve language model efficiency (e.g. by reducing
memory consumption) without compromising on recall. By applying experiments and
theory to a broad set of architectures, we identify a key tradeoff between a
model's state size and recall ability. We show that efficient alternatives to
attention (e.g. H3, Mamba, RWKV) maintain a fixed-size recurrent state, but
struggle at recall. We propose BASED a simple architecture combining linear and
sliding window attention. By varying BASED window size and linear attention
feature dimension, we can dial the state size and traverse the pareto frontier
of the recall-memory tradeoff curve, recovering the full quality of attention
on one end and the small state size of attention-alternatives on the other. We
train language models up to 1.3b parameters and show that BASED matches the
strongest sub-quadratic models (e.g. Mamba) in perplexity and outperforms them
on real-world recall-intensive tasks by 6.22 accuracy points. Implementations
of linear attention are often less efficient than optimized standard attention
implementations. To make BASED competitive, we develop IO-aware algorithms that
enable 24x higher throughput on language generation than FlashAttention-2, when
generating 1024 tokens using 1.3b parameter models. Code for this work is
provided at: https://github.com/HazyResearch/based.