ChatPaper.aiChatPaper

SageAttention3: Atención de FP4 en Microscala para Inferencia y una Exploración del Entrenamiento en 8 Bits

SageAttention3: Microscaling FP4 Attention for Inference and An Exploration of 8-Bit Training

May 16, 2025
Autores: Jintao Zhang, Jia Wei, Pengle Zhang, Xiaoming Xu, Haofeng Huang, Haoxu Wang, Kai Jiang, Jun Zhu, Jianfei Chen
cs.AI

Resumen

La eficiencia de la atención es importante debido a su complejidad temporal cuadrática. Mejoramos la eficiencia de la atención a través de dos contribuciones clave: En primer lugar, aprovechamos los nuevos Tensor Cores FP4 en las GPU Blackwell para acelerar el cálculo de la atención. Nuestra implementación alcanza 1038 TOPS en la RTX5090, lo que representa una aceleración de 5x sobre la implementación más rápida de FlashAttention en la RTX5090. Los experimentos muestran que nuestra atención FP4 puede acelerar la inferencia de varios modelos de manera plug-and-play. En segundo lugar, somos pioneros en aplicar la atención de bajo bit a tareas de entrenamiento. Trabajos existentes sobre atención de bajo bit, como FlashAttention3 y SageAttention, se centran únicamente en la inferencia. Sin embargo, la eficiencia en el entrenamiento de modelos grandes también es crucial. Para explorar si la atención de bajo bit puede aplicarse efectivamente a tareas de entrenamiento, diseñamos una atención de 8 bits precisa y eficiente tanto para la propagación hacia adelante como hacia atrás. Los experimentos indican que la atención de 8 bits logra un rendimiento sin pérdidas en tareas de ajuste fino, pero muestra una convergencia más lenta en tareas de preentrenamiento. El código estará disponible en https://github.com/thu-ml/SageAttention.
English
The efficiency of attention is important due to its quadratic time complexity. We enhance the efficiency of attention through two key contributions: First, we leverage the new FP4 Tensor Cores in Blackwell GPUs to accelerate attention computation. Our implementation achieves 1038 TOPS on RTX5090, which is a 5x speedup over the fastest FlashAttention on RTX5090. Experiments show that our FP4 attention can accelerate inference of various models in a plug-and-play way. Second, we pioneer low-bit attention to training tasks. Existing low-bit attention works like FlashAttention3 and SageAttention focus only on inference. However, the efficiency of training large models is also important. To explore whether low-bit attention can be effectively applied to training tasks, we design an accurate and efficient 8-bit attention for both forward and backward propagation. Experiments indicate that 8-bit attention achieves lossless performance in fine-tuning tasks but exhibits slower convergence in pretraining tasks. The code will be available at https://github.com/thu-ml/SageAttention.
PDF766May 21, 2025