BurstAttention: 극도로 긴 시퀀스를 위한 효율적인 분산 어텐션 프레임워크
BurstAttention: An Efficient Distributed Attention Framework for Extremely Long Sequences
March 14, 2024
저자: Sun Ao, Weilin Zhao, Xu Han, Cheng Yang, Zhiyuan Liu, Chuan Shi, Maosong Sun, Shengnan Wang, Teng Su
cs.AI
초록
효과적인 어텐션 모듈은 Transformer 기반 대규모 언어 모델(LLMs)의 성공에 중요한 역할을 해왔지만, 이러한 어텐션 모듈의 2차 시간 및 메모리 복잡도는 긴 시퀀스를 처리할 때 문제를 야기합니다. 긴 시퀀스 문제에 대한 한 가지 잠재적인 해결책은 분산 클러스터를 활용하여 어텐션 모듈의 계산을 여러 장치(예: GPU)에 걸쳐 병렬화하는 것입니다. 그러나 분산 접근 방식을 채택하면 로컬 어텐션 결과를 저장하기 위한 추가 메모리 오버헤드가 발생하고, 로컬 결과를 글로벌 결과로 집계하기 위한 추가 통신 비용이 불가피하게 발생합니다. 본 논문에서는 글로벌 클러스터 및 로컬 장치 수준에서 메모리 접근 및 통신 작업을 최적화하기 위해 "BurstAttention"이라는 분산 어텐션 프레임워크를 제안합니다. 실험에서는 긴 시퀀스 처리를 위한 다른 경쟁적인 분산 어텐션 솔루션과 BurstAttention을 비교합니다. 다양한 길이 설정에서의 실험 결과는 BurstAttention이 이러한 경쟁적인 베이스라인에 비해 긴 시퀀스 처리에 있어 상당한 이점을 제공하며, 8개의 A100에서 32K 시퀀스 길이를 학습하는 동안 40%의 통신 오버헤드를 줄이고 2배의 속도 향상을 달성함을 보여줍니다.
English
Effective attention modules have played a crucial role in the success of
Transformer-based large language models (LLMs), but the quadratic time and
memory complexities of these attention modules also pose a challenge when
processing long sequences. One potential solution for the long sequence problem
is to utilize distributed clusters to parallelize the computation of attention
modules across multiple devices (e.g., GPUs). However, adopting a distributed
approach inevitably introduces extra memory overheads to store local attention
results and incurs additional communication costs to aggregate local results
into global ones. In this paper, we propose a distributed attention framework
named ``BurstAttention'' to optimize memory access and communication operations
at both the global cluster and local device levels. In our experiments, we
compare BurstAttention with other competitive distributed attention solutions
for long sequence processing. The experimental results under different length
settings demonstrate that BurstAttention offers significant advantages for
processing long sequences compared with these competitive baselines, reducing
40% communication overheads and achieving 2 X speedup during training 32K
sequence length on 8 X A100.Summary
AI-Generated Summary