Faster Causal Attention Over Large Sequences Through Sparse Flash Attention
June 1, 2023
Authors: Matteo Pagliardini, Daniele Paliotta, Martin Jaggi, François Fleuret
cs.AI
Abstract
Transformer-based language models have found many diverse applications requiring them to process sequences of increasing length. For these applications, the causal self-attention -- which is the only component scaling quadratically w.r.t. the sequence length -- becomes a central concern. While many works have proposed schemes to sparsify the attention patterns and reduce the computational overhead of self-attention, those are often limited by implementation concerns and end up imposing a simple and static structure over the attention matrix. Conversely, implementing more dynamic sparse attention patterns often results in runtimes significantly slower than computing the full attention using the Flash implementation from Dao et al. (2022). We extend FlashAttention to accommodate a large class of attention sparsity patterns that, in particular, encompass key/query dropping and hashing-based attention. This leads to implementations with no computational complexity overhead and a multi-fold runtime speedup on top of FlashAttention. Even with relatively low degrees of sparsity, our method improves visibly upon FlashAttention as the sequence length increases. Without sacrificing perplexity, we increase the training speed of a transformer language model by 2.0× and 3.3× for sequences of 8k and 16k tokens, respectively.
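To make the hashing-based sparsity pattern mentioned in the abstract concrete, here is a minimal sketch in plain PyTorch: queries and keys are bucketed with a shared random projection, and each query attends only to earlier keys that fall in the same bucket. This is an illustrative reference for the attention pattern under assumed names (hash_buckets, hash_sparse_causal_attention are hypothetical), not the paper's kernel; the paper's contribution is a FlashAttention-style implementation that exploits such sparsity without materializing the full attention matrix.

```python
import math
import torch


def hash_buckets(x: torch.Tensor, proj: torch.Tensor) -> torch.Tensor:
    """Assign each position to a bucket via a shared random projection
    (LSH-style rounding; names are illustrative, not from the paper)."""
    # x: (batch, seq, dim), proj: (dim, n_buckets) -> bucket ids (batch, seq)
    return (x @ proj).argmax(dim=-1)


def hash_sparse_causal_attention(q, k, v, n_buckets: int = 8, seed: int = 0):
    """Quadratic-memory reference for the *attention pattern* only:
    a query attends to earlier keys that hash to the same bucket."""
    b, t, d = q.shape
    gen = torch.Generator(device=q.device).manual_seed(seed)
    proj = torch.randn(d, n_buckets, generator=gen, device=q.device, dtype=q.dtype)
    qb = hash_buckets(q, proj)  # (b, t)
    kb = hash_buckets(k, proj)  # (b, t)

    scores = (q @ k.transpose(-1, -2)) / math.sqrt(d)  # (b, t, t)
    causal = torch.tril(torch.ones(t, t, dtype=torch.bool, device=q.device))
    same_bucket = qb.unsqueeze(-1) == kb.unsqueeze(-2)  # (b, t, t)
    mask = causal & same_bucket
    # Keep the diagonal so every query attends to at least itself.
    mask |= torch.eye(t, dtype=torch.bool, device=q.device)
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


if __name__ == "__main__":
    x = torch.randn(2, 128, 64)
    out = hash_sparse_causal_attention(x, x, x)
    print(out.shape)  # torch.Size([2, 128, 64])
```

Key/query dropping, the other pattern named in the abstract, can be viewed analogously: instead of a bucket-match condition, the mask removes a subset of query and key positions entirely, and the kernel skips the corresponding blocks.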