ChatPaper.aiChatPaper

Attenzione Causale Più Veloce su Sequenze Ampie Tramite Sparse Flash Attention

Faster Causal Attention Over Large Sequences Through Sparse Flash Attention

June 1, 2023
Autori: Matteo Pagliardini, Daniele Paliotta, Martin Jaggi, François Fleuret
cs.AI

Abstract

I modelli linguistici basati su Transformer hanno trovato molteplici applicazioni che richiedono loro di elaborare sequenze di lunghezza crescente. Per queste applicazioni, l'attenzione causale (self-attention) — che è l'unico componente che scala quadraticamente rispetto alla lunghezza della sequenza — diventa una preoccupazione centrale. Sebbene molti lavori abbiano proposto schemi per sparsificare i pattern di attenzione e ridurre il sovraccarico computazionale della self-attention, questi sono spesso limitati da problemi di implementazione e finiscono per imporre una struttura semplice e statica sulla matrice di attenzione. Al contrario, implementare attenzioni sparse più dinamiche spesso si traduce in tempi di esecuzione significativamente più lenti rispetto al calcolo dell'attenzione completa utilizzando l'implementazione Flash di Dao et al. (2022). Estendiamo FlashAttention per supportare una vasta classe di pattern di sparsità dell'attenzione che, in particolare, includono l'eliminazione di chiavi/query e l'attenzione basata su hashing. Ciò porta a implementazioni senza sovraccarico di complessità computazionale e con un'accelerazione multipla del tempo di esecuzione rispetto a FlashAttention. Anche con gradi di sparsità relativamente bassi, il nostro metodo migliora visibilmente rispetto a FlashAttention all'aumentare della lunghezza della sequenza. Senza sacrificare la perplessità, aumentiamo la velocità di addestramento di un modello linguistico Transformer di 2,0 volte e 3,3 volte per sequenze rispettivamente di 8k e 16k token.
English
Transformer-based language models have found many diverse applications requiring them to process sequences of increasing length. For these applications, the causal self-attention -- which is the only component scaling quadratically w.r.t. the sequence length -- becomes a central concern. While many works have proposed schemes to sparsify the attention patterns and reduce the computational overhead of self-attention, those are often limited by implementations concerns and end up imposing a simple and static structure over the attention matrix. Conversely, implementing more dynamic sparse attentions often results in runtimes significantly slower than computing the full attention using the Flash implementation from Dao et al. (2022). We extend FlashAttention to accommodate a large class of attention sparsity patterns that, in particular, encompass key/query dropping and hashing-based attention. This leads to implementations with no computational complexity overhead and a multi-fold runtime speedup on top of FlashAttention. Even with relatively low degrees of sparsity, our method improves visibly upon FlashAttention as the sequence length increases. Without sacrificing perplexity, we increase the training speed of a transformer language model by 2.0times and 3.3times for sequences of respectively 8k and 16k tokens.
PDF12March 25, 2026