StreamBP: Retropropagación Exacta Eficiente en Memoria para el Entrenamiento de Secuencias Largas en Modelos de Lenguaje de Gran Escala
StreamBP: Memory-Efficient Exact Backpropagation for Long Sequence Training of LLMs
June 3, 2025
Autores: Qijun Luo, Mengqi Li, Lei Zhao, Xiao Li
cs.AI
Resumen
Entrenar modelos de lenguaje en datos de secuencias largas es un requisito exigente para mejorar la capacidad del modelo en tareas complejas, como el razonamiento de cadena larga. Sin embargo, a medida que la longitud de la secuencia aumenta, el costo de memoria para almacenar los valores de activación se vuelve enorme durante el proceso de Retropropagación (BP), incluso con la aplicación de la técnica de checkpointing de gradientes. Para abordar este desafío, proponemos un método de BP eficiente en memoria y exacto llamado StreamBP, que realiza una descomposición lineal de la regla de la cadena a lo largo de la dimensión de la secuencia de manera capa por capa, reduciendo significativamente el costo de memoria de los valores de activación y logits. El método propuesto es aplicable a objetivos comunes como SFT, GRPO y DPO. Desde una perspectiva de implementación, StreamBP logra menos operaciones de punto flotante (FLOPs) y una velocidad de BP más rápida al aprovechar la estructura causal del modelo de lenguaje. En comparación con el checkpointing de gradientes, StreamBP escala la longitud máxima de secuencia de BP entre 2.8 y 5.5 veces más, mientras utiliza un tiempo de BP comparable o incluso menor. Cabe destacar que la capacidad de escalado de longitud de secuencia de StreamBP puede transferirse directamente al escalado del tamaño del lote para acelerar el entrenamiento. Además, desarrollamos una versión distribuida de StreamBP eficiente en comunicación para apoyar efectivamente el entrenamiento multi-GPU y ampliar su aplicabilidad. Nuestro código puede integrarse fácilmente en la tubería de entrenamiento de cualquier modelo transformador y está disponible en https://github.com/Ledzy/StreamBP.
English
Training language models on long sequence data is a demanding requirement for
enhancing the model's capability on complex tasks, e.g., long-chain reasoning.
However, as the sequence length scales up, the memory cost for storing
activation values becomes huge during the Backpropagation (BP) process, even
with the application of gradient checkpointing technique. To tackle this
challenge, we propose a memory-efficient and exact BP method called StreamBP,
which performs a linear decomposition of the chain rule along the sequence
dimension in a layer-wise manner, significantly reducing the memory cost of
activation values and logits. The proposed method is applicable to common
objectives such as SFT, GRPO, and DPO. From an implementation perspective,
StreamBP achieves less computational FLOPs and faster BP speed by leveraging
the causal structure of the language model. Compared to gradient checkpointing,
StreamBP scales up the maximum sequence length of BP by 2.8-5.5 times larger,
while using comparable or even less BP time. Note that StreamBP's sequence
length scaling ability can be directly transferred to batch size scaling for
accelerating training. We further develop a communication-efficient distributed
StreamBP to effectively support multi-GPU training and broaden its
applicability. Our code can be easily integrated into the training pipeline of
any transformer models and is available at https://github.com/Ledzy/StreamBP.