ChatPaper.aiChatPaper

SageAttention3: Microscaling FP4 per l'attenzione nell'inferenza e un'esplorazione dell'addestramento a 8 bit

SageAttention3: Microscaling FP4 Attention for Inference and An Exploration of 8-Bit Training

May 16, 2025
Autori: Jintao Zhang, Jia Wei, Pengle Zhang, Xiaoming Xu, Haofeng Huang, Haoxu Wang, Kai Jiang, Jun Zhu, Jianfei Chen
cs.AI

Abstract

L'efficienza dell'attenzione è cruciale a causa della sua complessità temporale quadratica. Miglioriamo l'efficienza dell'attenzione attraverso due contributi chiave: in primo luogo, sfruttiamo i nuovi Tensor Core FP4 nelle GPU Blackwell per accelerare il calcolo dell'attenzione. La nostra implementazione raggiunge 1038 TOPS su RTX5090, ottenendo un incremento di velocità di 5x rispetto alla più veloce FlashAttention su RTX5090. Gli esperimenti dimostrano che la nostra attenzione FP4 può accelerare l'inferenza di vari modelli in modo plug-and-play. In secondo luogo, siamo pionieri nell'applicazione dell'attenzione a basso bit ai task di addestramento. Le attuali soluzioni di attenzione a basso bit, come FlashAttention3 e SageAttention, si concentrano solo sull'inferenza. Tuttavia, l'efficienza nell'addestramento di modelli di grandi dimensioni è altrettanto importante. Per esplorare se l'attenzione a basso bit possa essere efficacemente applicata ai task di addestramento, progettiamo un'attenzione a 8 bit precisa ed efficiente sia per la propagazione in avanti che per quella all'indietro. Gli esperimenti indicano che l'attenzione a 8 bit raggiunge prestazioni senza perdite nei task di fine-tuning, ma mostra una convergenza più lenta nei task di pre-addestramento. Il codice sarà disponibile su https://github.com/thu-ml/SageAttention.
English
The efficiency of attention is important due to its quadratic time complexity. We enhance the efficiency of attention through two key contributions: First, we leverage the new FP4 Tensor Cores in Blackwell GPUs to accelerate attention computation. Our implementation achieves 1038 TOPS on RTX5090, which is a 5x speedup over the fastest FlashAttention on RTX5090. Experiments show that our FP4 attention can accelerate inference of various models in a plug-and-play way. Second, we pioneer low-bit attention to training tasks. Existing low-bit attention works like FlashAttention3 and SageAttention focus only on inference. However, the efficiency of training large models is also important. To explore whether low-bit attention can be effectively applied to training tasks, we design an accurate and efficient 8-bit attention for both forward and backward propagation. Experiments indicate that 8-bit attention achieves lossless performance in fine-tuning tasks but exhibits slower convergence in pretraining tasks. The code will be available at https://github.com/thu-ml/SageAttention.
PDF766May 21, 2025