Efficient Long-context Language Model Training by Core Attention Disaggregation
October 20, 2025
Authors: Yonghao Zhuang, Junda Chen, Bo Pang, Yi Gu, Yibo Zhu, Yimin Jiang, Ion Stoica, Eric Xing, Hao Zhang
cs.AI
Abstract
We present core attention disaggregation (CAD), a technique that improves
long-context large language model training by decoupling the core attention
computation, softmax(QK^T)V, from the rest of the model and executing it on a
separate pool of devices. In existing systems, core attention is colocated with
other layers; at long context lengths, its quadratic compute growth compared to
the near-linear growth of other components causes load imbalance and stragglers
across data and pipeline parallel groups. CAD is enabled by two observations.
First, core attention is stateless: it has no trainable parameters and only
minimal transient data, so balancing reduces to scheduling compute-bound tasks.
Second, it is composable: modern attention kernels retain high efficiency when
processing fused batches of token-level shards with arbitrary lengths. CAD
partitions core attention into token-level tasks and dispatches them to
dedicated attention servers, which dynamically rebatch tasks to equalize
compute without sacrificing kernel efficiency. We implement CAD in a system
called DistCA, which uses a ping-pong execution scheme to fully overlap
communication with computation and in-place execution on attention servers to
reduce memory use. On 512 H200 GPUs and context lengths up to 512k tokens,
DistCA improves end-to-end training throughput by up to 1.35x, eliminates data
and pipeline parallel stragglers, and achieves near-perfect compute and memory
balance.
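The abstract rests on two properties of core attention: it is stateless (no trainable parameters, so it can run on any device that holds Q, K, V) and its cost grows quadratically with context length, so balancing reduces to scheduling compute-bound, token-level tasks. The sketch below is not the authors' DistCA implementation; it is a minimal illustration under simplifying assumptions. The helper names (`core_attention`, `attention_cost`, `balance_shards`), the (q_len, kv_len) cost model (which ignores causal masking and head counts), and the greedy longest-processing-time assignment are all hypothetical stand-ins for the paper's dynamic rebatching on attention servers.

```python
# Illustrative sketch only, not the DistCA scheduler.
import math
from typing import List, Tuple

import torch


def core_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Stateless core attention, softmax(QK^T)V.

    No trainable parameters are involved, so any device holding (q, k, v)
    can execute it; this is what makes the computation easy to relocate.
    """
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    return torch.matmul(torch.softmax(scores, dim=-1), v)


def attention_cost(q_len: int, kv_len: int) -> int:
    """Toy cost model: core-attention FLOPs scale with q_len * kv_len,
    i.e. quadratically in context length for a full sequence."""
    return q_len * kv_len


def balance_shards(shards: List[Tuple[int, int]], num_servers: int) -> List[List[int]]:
    """Greedy longest-processing-time assignment of token-level shards,
    given as (q_len, kv_len) pairs, to attention servers so each server
    receives roughly equal compute. A real system would additionally
    rebatch the assigned shards into fused kernel calls."""
    order = sorted(range(len(shards)), key=lambda i: attention_cost(*shards[i]), reverse=True)
    loads = [0] * num_servers
    plan: List[List[int]] = [[] for _ in range(num_servers)]
    for i in order:
        target = min(range(num_servers), key=lambda j: loads[j])
        plan[target].append(i)
        loads[target] += attention_cost(*shards[i])
    return plan


if __name__ == "__main__":
    # Sequences of very different lengths: with colocated attention the
    # 512k-token sequence makes its data-parallel rank a straggler;
    # cost-aware placement spreads the quadratic work instead.
    shards = [(512_000, 512_000), (64_000, 64_000), (32_000, 32_000), (8_000, 8_000)]
    print(balance_shards(shards, num_servers=2))
```

In this toy setting the scheduler only balances aggregate FLOPs; the paper's contribution additionally covers keeping fused kernels efficient on arbitrary-length shards, overlapping the dispatch communication with compute via ping-pong execution, and executing in place on the attention servers to limit memory use.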