ChatPaper.aiChatPaper

大規模シーケンスにおける高速な因果的注意機構:スパースフラッシュアテンションを通じて

Faster Causal Attention Over Large Sequences Through Sparse Flash Attention

June 1, 2023
著者: Matteo Pagliardini, Daniele Paliotta, Martin Jaggi, François Fleuret
cs.AI

要旨

Transformerベースの言語モデルは、ますます長いシーケンスを処理する必要がある多様なアプリケーションで活用されています。これらのアプリケーションにおいて、シーケンス長に対して二次的にスケーリングする唯一のコンポーネントである因果的セルフアテンションが中心的な課題となっています。多くの研究がアテンションパターンをスパース化し、セルフアテンションの計算オーバーヘッドを削減する手法を提案していますが、それらは実装上の制約により、アテンションマトリックスに単純で静的な構造を課すことが多いです。一方、より動的なスパースアテンションを実装すると、Daoら(2022)のFlash実装を使用して完全なアテンションを計算するよりも、実行時間が大幅に遅くなることがよくあります。本研究では、FlashAttentionを拡張し、特にキー/クエリのドロップやハッシュベースのアテンションを含む、幅広いアテンションのスパースパターンをサポートします。これにより、計算量のオーバーヘッドなしに実装が可能となり、FlashAttentionを上回る複数倍の実行速度向上を実現します。比較的低いスパース度であっても、シーケンス長が増加するにつれて、本手法はFlashAttentionを目に見えて改善します。パープレキシティを犠牲にすることなく、Transformer言語モデルのトレーニング速度を、それぞれ8kおよび16kトークンのシーケンスに対して2.0倍および3.3倍向上させました。
English
Transformer-based language models have found many diverse applications requiring them to process sequences of increasing length. For these applications, the causal self-attention -- which is the only component scaling quadratically w.r.t. the sequence length -- becomes a central concern. While many works have proposed schemes to sparsify the attention patterns and reduce the computational overhead of self-attention, those are often limited by implementations concerns and end up imposing a simple and static structure over the attention matrix. Conversely, implementing more dynamic sparse attentions often results in runtimes significantly slower than computing the full attention using the Flash implementation from Dao et al. (2022). We extend FlashAttention to accommodate a large class of attention sparsity patterns that, in particular, encompass key/query dropping and hashing-based attention. This leads to implementations with no computational complexity overhead and a multi-fold runtime speedup on top of FlashAttention. Even with relatively low degrees of sparsity, our method improves visibly upon FlashAttention as the sequence length increases. Without sacrificing perplexity, we increase the training speed of a transformer language model by 2.0times and 3.3times for sequences of respectively 8k and 16k tokens.
PDF12December 15, 2024