Faster Video Diffusion with Trainable Sparse Attention
May 19, 2025
Authors: Peiyuan Zhang, Haofeng Huang, Yongqi Chen, Will Lin, Zhengzhong Liu, Ion Stoica, Eric P. Xing, Hao Zhang
cs.AI
Abstract
Scaling video diffusion transformers (DiTs) is limited by their quadratic 3D attention, even though most of the attention mass concentrates on a small subset of positions. We turn this observation into VSA, a trainable, hardware-efficient sparse attention that replaces full attention at both training and inference. In VSA, a lightweight coarse stage pools tokens into tiles and identifies high-weight critical tokens; a fine stage computes token-level attention only inside those tiles, subject to a block computing layout to ensure hardware efficiency. This leads to a single differentiable kernel that trains end-to-end, requires no post-hoc profiling, and sustains 85% of FlashAttention3 MFU. We perform a large sweep of ablation studies and scaling-law experiments by pretraining DiTs from 60M to 1.4B parameters. VSA reaches a Pareto point that cuts training FLOPs by 2.53x with no drop in diffusion loss. Retrofitting the open-source Wan-2.1 model speeds up attention time by 6x and lowers end-to-end generation time from 31s to 18s with comparable quality. These results establish trainable sparse attention as a practical alternative to full attention and a key enabler for further scaling of video diffusion models.
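The coarse-to-fine selection described in the abstract can be sketched in plain PyTorch. The snippet below is a minimal illustration under stated assumptions, not the paper's fused kernel: the function name vsa_attention, the mean-pooling of tiles, and the parameters tile_size and top_k are chosen here for clarity, and the dense boolean mask stands in for the hardware-efficient block-sparse computation that VSA actually performs.

```python
# Minimal sketch of a coarse-to-fine sparse attention, assuming mean-pooled tile
# scores and a top-k tile selection; the real VSA kernel is fused and block-sparse.
import torch
import torch.nn.functional as F

def vsa_attention(q, k, v, tile_size=64, top_k=8):
    """q, k, v: (batch, heads, seq_len, dim); seq_len must be divisible by tile_size."""
    b, h, n, d = q.shape
    t = n // tile_size  # number of tiles

    # Coarse stage: pool queries/keys into tiles and score tile pairs.
    q_tiles = q.view(b, h, t, tile_size, d).mean(dim=3)          # (b, h, t, d)
    k_tiles = k.view(b, h, t, tile_size, d).mean(dim=3)          # (b, h, t, d)
    tile_scores = q_tiles @ k_tiles.transpose(-1, -2) / d ** 0.5  # (b, h, t, t)

    # Keep only the top-k key tiles for each query tile.
    topk_idx = tile_scores.topk(top_k, dim=-1).indices            # (b, h, t, top_k)

    # Fine stage: token-level attention restricted to the selected tiles,
    # expressed here as a dense mask for clarity (a real kernel gathers blocks).
    mask = torch.zeros(b, h, t, t, dtype=torch.bool, device=q.device)
    mask.scatter_(-1, topk_idx, True)
    token_mask = (mask.repeat_interleave(tile_size, dim=2)
                      .repeat_interleave(tile_size, dim=3))       # (b, h, n, n)

    return F.scaled_dot_product_attention(q, k, v, attn_mask=token_mask)

if __name__ == "__main__":
    q = torch.randn(1, 2, 256, 32)
    k = torch.randn(1, 2, 256, 32)
    v = torch.randn(1, 2, 256, 32)
    out = vsa_attention(q, k, v, tile_size=64, top_k=2)
    print(out.shape)  # torch.Size([1, 2, 256, 32])
```

Because the tile selection is computed from pooled scores inside the same forward pass, gradients flow through the fine-stage attention during training, which is what allows this style of sparsity to be learned end-to-end rather than fixed by post-hoc profiling.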