SageAttention3 : Attention FP4 à micro-échelle pour l'inférence et une exploration de l'entraînement en 8 bits
SageAttention3: Microscaling FP4 Attention for Inference and An Exploration of 8-Bit Training
May 16, 2025
Auteurs: Jintao Zhang, Jia Wei, Pengle Zhang, Xiaoming Xu, Haofeng Huang, Haoxu Wang, Kai Jiang, Jun Zhu, Jianfei Chen
cs.AI
Résumé
L'efficacité de l'attention est cruciale en raison de sa complexité temporelle quadratique. Nous améliorons l'efficacité de l'attention grâce à deux contributions majeures : Premièrement, nous exploitons les nouveaux Tensor Cores FP4 des GPU Blackwell pour accélérer le calcul de l'attention. Notre implémentation atteint 1038 TOPS sur le RTX5090, ce qui représente une accélération de 5x par rapport à la version la plus rapide de FlashAttention sur le RTX5090. Les expériences montrent que notre attention FP4 peut accélérer l'inférence de divers modèles de manière plug-and-play. Deuxièmement, nous sommes les premiers à appliquer l'attention à faible précision aux tâches d'entraînement. Les travaux existants sur l'attention à faible précision, comme FlashAttention3 et SageAttention, se concentrent uniquement sur l'inférence. Cependant, l'efficacité de l'entraînement des grands modèles est également importante. Pour explorer si l'attention à faible précision peut être efficacement appliquée aux tâches d'entraînement, nous concevons une attention 8 bits précise et efficace pour la propagation avant et arrière. Les expériences indiquent que l'attention 8 bits atteint des performances sans perte dans les tâches de fine-tuning, mais présente une convergence plus lente dans les tâches de pré-entraînement. Le code sera disponible à l'adresse https://github.com/thu-ml/SageAttention.
English
The efficiency of attention is important due to its quadratic time
complexity. We enhance the efficiency of attention through two key
contributions: First, we leverage the new FP4 Tensor Cores in Blackwell GPUs to
accelerate attention computation. Our implementation achieves 1038 TOPS on
RTX5090, which is a 5x speedup over the fastest FlashAttention on RTX5090.
Experiments show that our FP4 attention can accelerate inference of various
models in a plug-and-play way. Second, we pioneer low-bit attention to training
tasks. Existing low-bit attention works like FlashAttention3 and SageAttention
focus only on inference. However, the efficiency of training large models is
also important. To explore whether low-bit attention can be effectively applied
to training tasks, we design an accurate and efficient 8-bit attention for both
forward and backward propagation. Experiments indicate that 8-bit attention
achieves lossless performance in fine-tuning tasks but exhibits slower
convergence in pretraining tasks. The code will be available at
https://github.com/thu-ml/SageAttention.Summary
AI-Generated Summary