Attenzione Causale Più Veloce su Sequenze Ampie Tramite Sparse Flash Attention
Faster Causal Attention Over Large Sequences Through Sparse Flash Attention
June 1, 2023
Autori: Matteo Pagliardini, Daniele Paliotta, Martin Jaggi, François Fleuret
cs.AI
Abstract
I modelli linguistici basati su Transformer hanno trovato molteplici applicazioni che richiedono loro di elaborare sequenze di lunghezza crescente. Per queste applicazioni, l'attenzione causale (self-attention) — che è l'unico componente che scala quadraticamente rispetto alla lunghezza della sequenza — diventa una preoccupazione centrale. Sebbene molti lavori abbiano proposto schemi per sparsificare i pattern di attenzione e ridurre il sovraccarico computazionale della self-attention, questi sono spesso limitati da problemi di implementazione e finiscono per imporre una struttura semplice e statica sulla matrice di attenzione. Al contrario, implementare attenzioni sparse più dinamiche spesso si traduce in tempi di esecuzione significativamente più lenti rispetto al calcolo dell'attenzione completa utilizzando l'implementazione Flash di Dao et al. (2022). Estendiamo FlashAttention per supportare una vasta classe di pattern di sparsità dell'attenzione che, in particolare, includono l'eliminazione di chiavi/query e l'attenzione basata su hashing. Ciò porta a implementazioni senza sovraccarico di complessità computazionale e con un'accelerazione multipla del tempo di esecuzione rispetto a FlashAttention. Anche con gradi di sparsità relativamente bassi, il nostro metodo migliora visibilmente rispetto a FlashAttention all'aumentare della lunghezza della sequenza. Senza sacrificare la perplessità, aumentiamo la velocità di addestramento di un modello linguistico Transformer di 2,0 volte e 3,3 volte per sequenze rispettivamente di 8k e 16k token.
English
Transformer-based language models have found many diverse applications
requiring them to process sequences of increasing length. For these
applications, the causal self-attention -- which is the only component scaling
quadratically w.r.t. the sequence length -- becomes a central concern. While
many works have proposed schemes to sparsify the attention patterns and reduce
the computational overhead of self-attention, those are often limited by
implementations concerns and end up imposing a simple and static structure over
the attention matrix. Conversely, implementing more dynamic sparse attentions
often results in runtimes significantly slower than computing the full
attention using the Flash implementation from Dao et al. (2022). We extend
FlashAttention to accommodate a large class of attention sparsity patterns
that, in particular, encompass key/query dropping and hashing-based attention.
This leads to implementations with no computational complexity overhead and a
multi-fold runtime speedup on top of FlashAttention. Even with relatively low
degrees of sparsity, our method improves visibly upon FlashAttention as the
sequence length increases. Without sacrificing perplexity, we increase the
training speed of a transformer language model by 2.0times and 3.3times
for sequences of respectively 8k and 16k tokens.