SageAttention3: Microscaling FP4 Attention for Inference and An Exploration of 8-Bit Training
May 16, 2025
Authors: Jintao Zhang, Jia Wei, Pengle Zhang, Xiaoming Xu, Haofeng Huang, Haoxu Wang, Kai Jiang, Jun Zhu, Jianfei Chen
cs.AI
Abstract
The efficiency of attention is important due to its quadratic time
complexity. We enhance the efficiency of attention through two key
contributions: First, we leverage the new FP4 Tensor Cores in Blackwell GPUs to
accelerate attention computation. Our implementation achieves 1038 TOPS on
RTX5090, a 5x speedup over the fastest FlashAttention on the same GPU.
Experiments show that our FP4 attention can accelerate inference of various
models in a plug-and-play way. Second, we pioneer the application of low-bit
attention to training tasks. Existing low-bit attention works, such as
FlashAttention3 and SageAttention,
focus only on inference. However, the efficiency of training large models is
also important. To explore whether low-bit attention can be effectively applied
to training tasks, we design an accurate and efficient 8-bit attention for both
forward and backward propagation. Experiments indicate that 8-bit attention
achieves lossless performance in fine-tuning tasks but exhibits slower
convergence in pretraining tasks. The code will be available at
https://github.com/thu-ml/SageAttention.
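For readers unfamiliar with the microscaling FP4 format named in the title, the sketch below illustrates the basic idea: a tensor is quantized in small blocks, each block sharing one scale factor, with individual elements stored as 4-bit floating-point (E2M1) values. The block size of 16, the scale computation, and all function names are illustrative assumptions for exposition only; they are not taken from the paper's kernels, whose speedup comes from running the quantized matmuls on Blackwell FP4 Tensor Cores.

```python
# Minimal NumPy sketch of microscaling FP4 (E2M1) quantization, as might be
# applied to attention inputs. Block size, scale format, and function names
# are assumptions for illustration, not the paper's implementation.
import numpy as np

# Representable magnitudes of the FP4 E2M1 format.
FP4_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_mxfp4(x: np.ndarray, block: int = 16):
    """Quantize the last dimension of `x` in blocks that share one scale."""
    *lead, d = x.shape
    assert d % block == 0, "last dim must be divisible by the block size"
    xb = x.reshape(*lead, d // block, block)
    # One scale per block so the largest magnitude maps to 6.0 (FP4 max).
    scale = np.abs(xb).max(axis=-1, keepdims=True) / FP4_GRID[-1]
    scale = np.where(scale == 0, 1.0, scale)
    # Round each scaled element to the nearest representable FP4 value.
    scaled = xb / scale
    idx = np.abs(np.abs(scaled)[..., None] - FP4_GRID).argmin(axis=-1)
    q = np.sign(scaled) * FP4_GRID[idx]
    return q, scale

def dequantize_mxfp4(q: np.ndarray, scale: np.ndarray, shape):
    """Recover an approximation of the original tensor."""
    return (q * scale).reshape(shape)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical query tensor with layout (batch, heads, seq_len, head_dim).
    q_heads = rng.standard_normal((2, 8, 128, 64)).astype(np.float32)
    q4, s = quantize_mxfp4(q_heads)
    recon = dequantize_mxfp4(q4, s, q_heads.shape)
    print("mean abs reconstruction error:", np.abs(recon - q_heads).mean())
```

Sharing a scale per small block limits the influence of outliers to their own block rather than the whole tensor, which is what makes 4-bit attention inputs viable in practice.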