コアアテンションの分解による効率的な長文脈言語モデル訓練
Efficient Long-context Language Model Training by Core Attention Disaggregation
October 20, 2025
著者: Yonghao Zhuang, Junda Chen, Bo Pang, Yi Gu, Yibo Zhu, Yimin Jiang, Ion Stoica, Eric Xing, Hao Zhang
cs.AI
要旨
本論文では、コアアテンションの分散処理(Core Attention Disaggregation, CAD)を提案する。この技術は、コアアテンション計算であるsoftmax(QK^T)Vをモデルの他の部分から切り離し、専用のデバイスプールで実行することで、長文脈大規模言語モデルの学習を改善する。既存のシステムでは、コアアテンションは他のレイヤーと共配置されているが、長い文脈長では、他のコンポーネントのほぼ線形な計算量増加に比べて二次的に増加するコアアテンションの計算量が、データ並列およびパイプライン並列グループ間での負荷不均衡や遅延を引き起こす。CADは、2つの観察に基づいて実現されている。第一に、コアアテンションはステートレスであり、学習可能なパラメータを持たず、最小限の一時データしか持たないため、負荷分散は計算バウンドなタスクのスケジューリングに帰着する。第二に、コアアテンションは合成可能であり、現代のアテンションカーネルは、任意の長さのトークンレベルの断片を融合したバッチで処理する際にも高い効率を維持する。CADは、コアアテンションをトークンレベルのタスクに分割し、それらを専用のアテンションサーバーにディスパッチする。これらのサーバーは、カーネル効率を犠牲にすることなく、計算量を均等化するためにタスクを動的に再バッチングする。我々は、DistCAというシステムにCADを実装した。DistCAは、ピンポン実行スキームを使用して通信と計算を完全にオーバーラップさせ、アテンションサーバー上でのインプレース実行によりメモリ使用量を削減する。512台のH200 GPUと最大512kトークンの文脈長において、DistCAはエンドツーエンドの学習スループットを最大1.35倍向上させ、データ並列およびパイプライン並列の遅延を排除し、ほぼ完璧な計算およびメモリバランスを達成する。
English
We present core attention disaggregation (CAD), a technique that improves
long-context large language model training by decoupling the core attention
computation, softmax(QK^T)V, from the rest of the model and executing it on a
separate pool of devices. In existing systems, core attention is colocated with
other layers; at long context lengths, its quadratic compute growth compared to
the near-linear growth of other components causes load imbalance and stragglers
across data and pipeline parallel groups. CAD is enabled by two observations.
First, core attention is stateless: it has no trainable parameters and only
minimal transient data, so balancing reduces to scheduling compute-bound tasks.
Second, it is composable: modern attention kernels retain high efficiency when
processing fused batches of token-level shards with arbitrary lengths. CAD
partitions core attention into token-level tasks and dispatches them to
dedicated attention servers, which dynamically rebatch tasks to equalize
compute without sacrificing kernel efficiency. We implement CAD in a system
called DistCA, which uses a ping-pong execution scheme to fully overlap
communication with computation and in-place execution on attention servers to
reduce memory use. On 512 H200 GPUs and context lengths up to 512k tokens,
DistCA improves end-to-end training throughput by up to 1.35x, eliminates data
and pipeline parallel stragglers, and achieves near-perfect compute and memory
balance.