SageAttention3: Microscaling FP4 per l'attenzione nell'inferenza e un'esplorazione dell'addestramento a 8 bit
SageAttention3: Microscaling FP4 Attention for Inference and An Exploration of 8-Bit Training
May 16, 2025
Autori: Jintao Zhang, Jia Wei, Pengle Zhang, Xiaoming Xu, Haofeng Huang, Haoxu Wang, Kai Jiang, Jun Zhu, Jianfei Chen
cs.AI
Abstract
L'efficienza dell'attenzione è cruciale a causa della sua complessità temporale quadratica. Miglioriamo l'efficienza dell'attenzione attraverso due contributi chiave: in primo luogo, sfruttiamo i nuovi Tensor Core FP4 nelle GPU Blackwell per accelerare il calcolo dell'attenzione. La nostra implementazione raggiunge 1038 TOPS su RTX5090, ottenendo un incremento di velocità di 5x rispetto alla più veloce FlashAttention su RTX5090. Gli esperimenti dimostrano che la nostra attenzione FP4 può accelerare l'inferenza di vari modelli in modo plug-and-play. In secondo luogo, siamo pionieri nell'applicazione dell'attenzione a basso bit ai task di addestramento. Le attuali soluzioni di attenzione a basso bit, come FlashAttention3 e SageAttention, si concentrano solo sull'inferenza. Tuttavia, l'efficienza nell'addestramento di modelli di grandi dimensioni è altrettanto importante. Per esplorare se l'attenzione a basso bit possa essere efficacemente applicata ai task di addestramento, progettiamo un'attenzione a 8 bit precisa ed efficiente sia per la propagazione in avanti che per quella all'indietro. Gli esperimenti indicano che l'attenzione a 8 bit raggiunge prestazioni senza perdite nei task di fine-tuning, ma mostra una convergenza più lenta nei task di pre-addestramento. Il codice sarà disponibile su https://github.com/thu-ml/SageAttention.
English
The efficiency of attention is important due to its quadratic time
complexity. We enhance the efficiency of attention through two key
contributions: First, we leverage the new FP4 Tensor Cores in Blackwell GPUs to
accelerate attention computation. Our implementation achieves 1038 TOPS on
RTX5090, which is a 5x speedup over the fastest FlashAttention on RTX5090.
Experiments show that our FP4 attention can accelerate inference of various
models in a plug-and-play way. Second, we pioneer low-bit attention to training
tasks. Existing low-bit attention works like FlashAttention3 and SageAttention
focus only on inference. However, the efficiency of training large models is
also important. To explore whether low-bit attention can be effectively applied
to training tasks, we design an accurate and efficient 8-bit attention for both
forward and backward propagation. Experiments indicate that 8-bit attention
achieves lossless performance in fine-tuning tasks but exhibits slower
convergence in pretraining tasks. The code will be available at
https://github.com/thu-ml/SageAttention.