Entrenamiento eficiente de modelos de lenguaje de contexto largo mediante la desagregación de la atención central
Efficient Long-context Language Model Training by Core Attention Disaggregation
October 20, 2025
Autores: Yonghao Zhuang, Junda Chen, Bo Pang, Yi Gu, Yibo Zhu, Yimin Jiang, Ion Stoica, Eric Xing, Hao Zhang
cs.AI
Resumen
Presentamos la desagregación de atención central (CAD, por sus siglas en inglés), una técnica que mejora el entrenamiento de modelos de lenguaje de gran contexto al desacoplar el cálculo de la atención central, softmax(QK^T)V, del resto del modelo y ejecutarlo en un grupo separado de dispositivos. En los sistemas existentes, la atención central se coloca junto con otras capas; en contextos largos, su crecimiento computacional cuadrático en comparación con el crecimiento casi lineal de otros componentes provoca desequilibrios de carga y retrasos en los grupos paralelos de datos y tuberías. CAD se basa en dos observaciones. Primero, la atención central no tiene estado: no tiene parámetros entrenables y solo datos transitorios mínimos, por lo que el equilibrio se reduce a la programación de tareas limitadas por el cálculo. Segundo, es componible: los núcleos de atención modernos mantienen una alta eficiencia al procesar lotes fusionados de fragmentos a nivel de token con longitudes arbitrarias. CAD divide la atención central en tareas a nivel de token y las distribuye a servidores de atención dedicados, que reagrupan dinámicamente las tareas para igualar el cálculo sin sacrificar la eficiencia del núcleo. Implementamos CAD en un sistema llamado DistCA, que utiliza un esquema de ejecución ping-pong para superponer completamente la comunicación con el cálculo y la ejecución en el lugar en los servidores de atención para reducir el uso de memoria. En 512 GPUs H200 y longitudes de contexto de hasta 512k tokens, DistCA mejora el rendimiento de entrenamiento de extremo a extremo hasta 1.35x, elimina los retrasos en los grupos paralelos de datos y tuberías, y logra un equilibrio casi perfecto de cálculo y memoria.
English
We present core attention disaggregation (CAD), a technique that improves
long-context large language model training by decoupling the core attention
computation, softmax(QK^T)V, from the rest of the model and executing it on a
separate pool of devices. In existing systems, core attention is colocated with
other layers; at long context lengths, its quadratic compute growth compared to
the near-linear growth of other components causes load imbalance and stragglers
across data and pipeline parallel groups. CAD is enabled by two observations.
First, core attention is stateless: it has no trainable parameters and only
minimal transient data, so balancing reduces to scheduling compute-bound tasks.
Second, it is composable: modern attention kernels retain high efficiency when
processing fused batches of token-level shards with arbitrary lengths. CAD
partitions core attention into token-level tasks and dispatches them to
dedicated attention servers, which dynamically rebatch tasks to equalize
compute without sacrificing kernel efficiency. We implement CAD in a system
called DistCA, which uses a ping-pong execution scheme to fully overlap
communication with computation and in-place execution on attention servers to
reduce memory use. On 512 H200 GPUs and context lengths up to 512k tokens,
DistCA improves end-to-end training throughput by up to 1.35x, eliminates data
and pipeline parallel stragglers, and achieves near-perfect compute and memory
balance.