ChatPaper.aiChatPaper

대규모 시퀀스에서의 빠른 인과적 어텐션: 스파스 플래시 어텐션을 통한 개선

Faster Causal Attention Over Large Sequences Through Sparse Flash Attention

June 1, 2023
저자: Matteo Pagliardini, Daniele Paliotta, Martin Jaggi, François Fleuret
cs.AI

초록

Transformer 기반 언어 모델은 점점 더 긴 시퀀스를 처리해야 하는 다양한 응용 분야에서 활용되고 있습니다. 이러한 응용 분야에서, 시퀀스 길이에 대해 2차적으로 스케일링되는 유일한 구성 요소인 인과적 자기 주의(causal self-attention)가 주요 관심사로 부각됩니다. 많은 연구에서 주의 패턴을 희소화(sparsify)하고 자기 주의의 계산 오버헤드를 줄이는 방안을 제안했지만, 이러한 방법들은 종종 구현상의 제약으로 인해 주의 행렬에 단순하고 정적인 구조를 부과하는 데 그치곤 합니다. 반면, 더 동적인 희소 주의를 구현하는 경우, Dao et al. (2022)의 Flash 구현을 사용하여 전체 주의를 계산하는 것보다 실행 시간이 현저히 느려지는 경우가 많습니다. 우리는 FlashAttention을 확장하여, 특히 키/쿼리 드롭핑(key/query dropping)과 해싱 기반 주의(hashing-based attention)를 포함하는 다양한 주의 희소 패턴을 수용할 수 있도록 했습니다. 이를 통해 계산 복잡성 오버헤드 없이 FlashAttention 위에서 다중 배수의 런타임 속도 향상을 달성했습니다. 비교적 낮은 희소도에서도, 시퀀스 길이가 증가함에 따라 우리의 방법은 FlashAttention보다 눈에 띄게 개선된 성능을 보입니다. perplexity를 희생하지 않으면서, 8k 및 16k 토큰 길이의 시퀀스에 대해 각각 2.0배 및 3.3배의 학습 속도 향상을 달성했습니다.
English
Transformer-based language models have found many diverse applications requiring them to process sequences of increasing length. For these applications, the causal self-attention -- which is the only component scaling quadratically w.r.t. the sequence length -- becomes a central concern. While many works have proposed schemes to sparsify the attention patterns and reduce the computational overhead of self-attention, those are often limited by implementations concerns and end up imposing a simple and static structure over the attention matrix. Conversely, implementing more dynamic sparse attentions often results in runtimes significantly slower than computing the full attention using the Flash implementation from Dao et al. (2022). We extend FlashAttention to accommodate a large class of attention sparsity patterns that, in particular, encompass key/query dropping and hashing-based attention. This leads to implementations with no computational complexity overhead and a multi-fold runtime speedup on top of FlashAttention. Even with relatively low degrees of sparsity, our method improves visibly upon FlashAttention as the sequence length increases. Without sacrificing perplexity, we increase the training speed of a transformer language model by 2.0times and 3.3times for sequences of respectively 8k and 16k tokens.
PDF12December 15, 2024