SageAttention3: Microscaling FP4 Attention for Inference and An Exploration of 8-Bit Training
May 16, 2025
Authors: Jintao Zhang, Jia Wei, Pengle Zhang, Xiaoming Xu, Haofeng Huang, Haoxu Wang, Kai Jiang, Jun Zhu, Jianfei Chen
cs.AI
Abstract
The efficiency of attention is important due to its quadratic time
complexity. We enhance the efficiency of attention through two key
contributions: First, we leverage the new FP4 Tensor Cores in Blackwell GPUs to
accelerate attention computation. Our implementation achieves 1038 TOPS on
RTX5090, a 5x speedup over the fastest FlashAttention implementation on the same GPU.
Experiments show that our FP4 attention can accelerate inference of various
models in a plug-and-play way. Second, we pioneer the application of low-bit
attention to training tasks. Existing low-bit attention works, such as
FlashAttention3 and SageAttention,
focus only on inference. However, the efficiency of training large models is
also important. To explore whether low-bit attention can be effectively applied
to training tasks, we design an accurate and efficient 8-bit attention for both
forward and backward propagation. Experiments indicate that 8-bit attention
achieves lossless performance in fine-tuning tasks but exhibits slower
convergence in pretraining tasks. The code will be available at
https://github.com/thu-ml/SageAttention.
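The abstract does not detail the quantization scheme, but the sketch below illustrates what microscaling FP4 means numerically: a tensor is fake-quantized in small contiguous blocks that share a single scale, with each element rounded onto the FP4 (E2M1) value grid. The block size of 16, the FP32 per-block scale, and the function name microscale_fp4_quantize are illustrative assumptions, not the paper's actual kernel, which runs natively on Blackwell FP4 Tensor Cores.

```python
import numpy as np

# Magnitudes representable in FP4 E2M1 (the sign bit is handled separately).
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def microscale_fp4_quantize(x: np.ndarray, block: int = 16):
    """Fake-quantize the last dimension of x in blocks of `block` elements.

    Returns the dequantized tensor and the per-block scales, which is enough
    to study the numerical error of FP4 attention offline.
    """
    orig_shape = x.shape
    blocks = x.reshape(-1, block)                    # one row per block
    scale = np.abs(blocks).max(axis=-1, keepdims=True) / E2M1_GRID[-1]
    scale = np.where(scale == 0, 1.0, scale)         # avoid dividing by zero
    # Round each scaled magnitude to the nearest representable E2M1 value.
    mag = np.abs(blocks) / scale
    idx = np.abs(mag[..., None] - E2M1_GRID).argmin(axis=-1)
    q = np.sign(blocks) * E2M1_GRID[idx]
    return (q * scale).reshape(orig_shape), scale

# Example: measure the quantization error on a random attention-score tile.
scores = np.random.randn(8, 128).astype(np.float32)
dequantized, _ = microscale_fp4_quantize(scores)
print("max abs error:", np.abs(scores - dequantized).max())
```

In a real kernel the 4-bit codes and per-block scales would be fed directly to the FP4 Tensor Core matrix units; the simulation above only reproduces the rounding behavior for offline accuracy studies.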