Efficient Long-context Language Model Training by Core Attention Disaggregation
October 20, 2025
Authors: Yonghao Zhuang, Junda Chen, Bo Pang, Yi Gu, Yibo Zhu, Yimin Jiang, Ion Stoica, Eric Xing, Hao Zhang
cs.AI
Abstract
We present core attention disaggregation (CAD), a technique that improves
long-context large language model training by decoupling the core attention
computation, softmax(QK^T)V, from the rest of the model and executing it on a
separate pool of devices. In existing systems, core attention is colocated with
other layers; at long context lengths, its quadratic compute growth compared to
the near-linear growth of other components causes load imbalance and stragglers
across data and pipeline parallel groups. CAD is enabled by two observations.
First, core attention is stateless: it has no trainable parameters and only
minimal transient data, so balancing reduces to scheduling compute-bound tasks.
Second, it is composable: modern attention kernels retain high efficiency when
processing fused batches of token-level shards with arbitrary lengths. CAD
partitions core attention into token-level tasks and dispatches them to
dedicated attention servers, which dynamically rebatch tasks to equalize
compute without sacrificing kernel efficiency. We implement CAD in a system
called DistCA, which uses a ping-pong execution scheme to fully overlap
communication with computation and in-place execution on attention servers to
reduce memory use. On 512 H200 GPUs and context lengths up to 512k tokens,
DistCA improves end-to-end training throughput by up to 1.35x, eliminates data
and pipeline parallel stragglers, and achieves near-perfect compute and memory
balance.
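
The imbalance and the token-level rebalancing described in the abstract can be illustrated with a small cost model. The sketch below is not DistCA's scheduler: it assumes causal attention, counts (query, key) pairs as a stand-in for core attention compute, ignores communication of Q/K/V between devices, and uses a plain greedy longest-task-first heuristic in place of whatever policy the system actually uses. The names `causal_pairs`, `shard_document`, and `balance` are hypothetical helpers introduced only for this illustration.

```python
import heapq


def causal_pairs(q_start, q_end):
    """(query, key) pairs for queries [q_start, q_end) of one document.

    Under a causal mask, the query at position i attends to keys 0..i,
    so the cost of a query shard is sum(i + 1 for i in range(q_start, q_end)).
    """
    return (q_end * (q_end + 1) - q_start * (q_start + 1)) // 2


def shard_document(doc_id, length, shard_tokens):
    """Split one document's core attention into token-level query shards.

    Each task is (cost, doc_id, q_start, q_end). Tasks are independent
    because core attention has no trainable parameters (an assumption
    taken from the abstract's "stateless" observation).
    """
    tasks = []
    for start in range(0, length, shard_tokens):
        end = min(start + shard_tokens, length)
        tasks.append((causal_pairs(start, end), doc_id, start, end))
    return tasks


def balance(tasks, num_servers):
    """Greedy longest-task-first assignment to the least-loaded server.

    This is an illustrative heuristic, not the paper's algorithm.
    Returns (assignment, loads), one entry per server.
    """
    heap = [(0, s) for s in range(num_servers)]
    assignment = [[] for _ in range(num_servers)]
    for cost, doc_id, start, end in sorted(tasks, reverse=True):
        load, s = heapq.heappop(heap)
        assignment[s].append((doc_id, start, end))
        heapq.heappush(heap, (load + cost, s))
    loads = [0] * num_servers
    for load, s in heap:
        loads[s] = load
    return assignment, loads


if __name__ == "__main__":
    # Four data-parallel ranks, each holding 8192 tokens packed differently.
    ranks = [[8192], [4096] * 2, [1024] * 8, [256] * 32]

    # Colocated: each rank computes attention for its own documents.
    colocated = [sum(causal_pairs(0, n) for n in docs) for docs in ranks]
    print("colocated max/min:", max(colocated) / min(colocated))

    # Disaggregated: shard every document, then rebatch across 4 servers.
    tasks = [
        task
        for r, docs in enumerate(ranks)
        for d, n in enumerate(docs)
        for task in shard_document((r, d), n, shard_tokens=256)
    ]
    _, loads = balance(tasks, num_servers=4)
    print("disaggregated max/min:", max(loads) / min(loads))
```

In this toy setup every rank holds the same number of tokens, so the near-linear components are already balanced, yet the rank with a single 8192-token document carries roughly 32 times the attention pairs of the rank with thirty-two 256-token documents. Sharding at the token level gives the scheduler many small tasks of known cost, which is what lets a simple heuristic bring the per-server loads much closer together.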