SageAttention3: 推論のためのFP4アテンションのマイクロスケーリングと8ビットトレーニングの探求
SageAttention3: Microscaling FP4 Attention for Inference and An Exploration of 8-Bit Training
May 16, 2025
著者: Jintao Zhang, Jia Wei, Pengle Zhang, Xiaoming Xu, Haofeng Huang, Haoxu Wang, Kai Jiang, Jun Zhu, Jianfei Chen
cs.AI
要旨
注意機構の効率性は、その二次的な時間複雑性のため重要である。本研究では、2つの主要な貢献を通じて注意機構の効率性を向上させる。第一に、Blackwell GPUの新しいFP4 Tensor Coreを活用して注意計算を高速化する。我々の実装はRTX5090上で1038 TOPSを達成し、RTX5090上で最速のFlashAttentionと比較して5倍の高速化を実現した。実験により、我々のFP4注意機構が様々なモデルの推論をプラグアンドプレイ方式で加速できることが示された。第二に、低ビット注意機構を学習タスクに初めて適用する。既存の低ビット注意機構であるFlashAttention3やSageAttentionは推論にのみ焦点を当てている。しかし、大規模モデルの学習効率も重要である。低ビット注意機構が学習タスクに効果的に適用できるかどうかを探るため、我々は順伝播と逆伝播の両方に対応する正確かつ効率的な8ビット注意機構を設計した。実験結果から、8ビット注意機構はファインチューニングタスクでは性能の損失なく機能するが、事前学習タスクでは収束が遅いことが示された。コードはhttps://github.com/thu-ml/SageAttentionで公開予定である。
English
The efficiency of attention is important due to its quadratic time
complexity. We enhance the efficiency of attention through two key
contributions: First, we leverage the new FP4 Tensor Cores in Blackwell GPUs to
accelerate attention computation. Our implementation achieves 1038 TOPS on
RTX5090, which is a 5x speedup over the fastest FlashAttention on RTX5090.
Experiments show that our FP4 attention can accelerate inference of various
models in a plug-and-play way. Second, we pioneer low-bit attention to training
tasks. Existing low-bit attention works like FlashAttention3 and SageAttention
focus only on inference. However, the efficiency of training large models is
also important. To explore whether low-bit attention can be effectively applied
to training tasks, we design an accurate and efficient 8-bit attention for both
forward and backward propagation. Experiments indicate that 8-bit attention
achieves lossless performance in fine-tuning tasks but exhibits slower
convergence in pretraining tasks. The code will be available at
https://github.com/thu-ml/SageAttention.Summary
AI-Generated Summary