SeerAttention: LLM内での固有スパース注意の学習
SeerAttention: Learning Intrinsic Sparse Attention in Your LLMs
October 17, 2024
著者: Yizhao Gao, Zhichen Zeng, Dayou Du, Shijie Cao, Hayden Kwok-Hay So, Ting Cao, Fan Yang, Mao Yang
cs.AI
要旨
近年の大規模言語モデル(LLM)において、注意機構は重要な要素となっています。しかしながら、その二次の計算量は、特に長いコンテキストウィンドウを持つLLMにおいて、効率性とスケーラビリティを制限しています。この制限に対処する有望なアプローチの1つは、注意のスパース性を活用することです。しかしながら、既存のスパース性に基づく解決策の多くは、スパース性を近似するために事前に定義されたパターンやヒューリスティックに依存しています。この手法は、言語タスクにおける注意のスパース性の動的性質を十分に捉えることができないという課題があります。本論文では、注意のスパース性は事前に定義するのではなく、学習すべきであると主張しています。このために、従来の注意機構に学習可能なゲートを追加し、注意マップ内の重要なブロックを適応的に選択し、残りのブロックをスパースと見なす新しいAttentionメカニズムであるSeerAttentionを設計しました。このブロックレベルのスパース性は、精度と高速化を効果的にバランスさせます。ゲーティングネットワークの効率的な学習を可能にするために、最小限のオーバーヘッドで注意マップのブロックレベルの正解を抽出するカスタマイズされたFlashAttention実装を開発しました。SeerAttentionは、事後トレーニングに適用されるだけでなく、長いコンテキストのファインチューニングにも優れています。実験結果は、事後トレーニング段階において、SeerAttentionが最先端の静的またはヒューリスティックに基づくスパース注意メソッドを大幅に上回ることを示し、さらに、異なるコンテキスト長やスパース率に適応する柔軟性と汎用性にも優れています。YaRNによる長いコンテキストのファインチューニングに適用すると、SeerAttentionは、最小の困惑度損失で32kコンテキスト長において90%のスパース率を達成し、FlashAttention-2に比べて5.67倍の高速化を実現します。
English
Attention is the cornerstone of modern Large Language Models (LLMs). Yet its
quadratic complexity limits the efficiency and scalability of LLMs, especially
for those with a long-context window. A promising approach addressing this
limitation is to leverage the sparsity in attention. However, existing
sparsity-based solutions predominantly rely on predefined patterns or
heuristics to approximate sparsity. This practice falls short to fully capture
the dynamic nature of attention sparsity in language-based tasks. This paper
argues that attention sparsity should be learned rather than predefined. To
this end, we design SeerAttention, a new Attention mechanism that augments the
conventional attention with a learnable gate that adaptively selects
significant blocks in an attention map and deems the rest blocks sparse. Such
block-level sparsity effectively balances accuracy and speedup. To enable
efficient learning of the gating network, we develop a customized
FlashAttention implementation that extracts the block-level ground truth of
attention map with minimum overhead. SeerAttention not only applies to
post-training, but also excels in long-context fine-tuning. Our results show
that at post-training stages, SeerAttention significantly outperforms
state-of-the-art static or heuristic-based sparse attention methods, while also
being more versatile and flexible to adapt to varying context lengths and
sparsity ratios. When applied to long-context fine-tuning with YaRN,
SeerAttention can achieve a remarkable 90% sparsity ratio at a 32k context
length with minimal perplexity loss, offering a 5.67x speedup over
FlashAttention-2.Summary
AI-Generated Summary