Faster Video Diffusion with Trainable Sparse Attention

May 19, 2025
Authors: Peiyuan Zhang, Haofeng Huang, Yongqi Chen, Will Lin, Zhengzhong Liu, Ion Stoica, Eric P. Xing, Hao Zhang
cs.AI

Abstract

Scaling video diffusion transformers (DiTs) is limited by their quadratic 3D attention, even though most of the attention mass concentrates on a small subset of positions. We turn this observation into VSA, a trainable, hardware-efficient sparse attention that replaces full attention at both training and inference. In VSA, a lightweight coarse stage pools tokens into tiles and identifies high-weight critical tokens; a fine stage computes token-level attention only inside those tiles, subject to a block computing layout to ensure hardware efficiency. This yields a single differentiable kernel that trains end-to-end, requires no post-hoc profiling, and sustains 85% of FlashAttention3 MFU. We perform a large sweep of ablation studies and scaling-law experiments by pretraining DiTs from 60M to 1.4B parameters. VSA reaches a Pareto point that cuts training FLOPs by 2.53× with no drop in diffusion loss. Retrofitting the open-source Wan-2.1 model speeds up attention time by 6× and lowers end-to-end generation time from 31s to 18s with comparable quality. These results establish trainable sparse attention as a practical alternative to full attention and a key enabler for further scaling of video diffusion models.
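To make the two-stage idea concrete, below is a minimal PyTorch sketch of coarse-to-fine sparse attention as described in the abstract. It is not the authors' fused kernel: the tile size, top-k budget, mean-pooling choice, and the function name vsa_like_attention are illustrative assumptions, and the real VSA implements both stages in a single hardware-efficient GPU kernel rather than a Python loop.

```python
# Hypothetical sketch of coarse-to-fine sparse attention (not the official VSA kernel).
import torch
import torch.nn.functional as F


def vsa_like_attention(q, k, v, tile: int = 64, topk: int = 8):
    """q, k, v: [batch, heads, seq, dim]; seq must be divisible by `tile`."""
    B, H, S, D = q.shape
    T = S // tile  # number of tiles along the sequence

    # --- Coarse stage: pool tokens into tiles and score tile-to-tile relevance ---
    q_tiles = q.view(B, H, T, tile, D).mean(dim=3)               # [B, H, T, D]
    k_tiles = k.view(B, H, T, tile, D).mean(dim=3)               # [B, H, T, D]
    tile_scores = q_tiles @ k_tiles.transpose(-1, -2) / D**0.5   # [B, H, T, T]
    # Keep the top-k key tiles per query tile (the "critical" tiles).
    sel = tile_scores.topk(min(topk, T), dim=-1).indices         # [B, H, T, topk]

    # --- Fine stage: token-level attention only inside the selected tiles ---
    k_blocks = k.view(B, H, T, tile, D)
    v_blocks = v.view(B, H, T, tile, D)
    out = torch.empty_like(q)
    for t in range(T):                                            # per query tile
        qt = q.view(B, H, T, tile, D)[:, :, t]                    # [B, H, tile, D]
        idx = sel[:, :, t]                                        # [B, H, topk]
        # Gather the selected key/value tiles for every (batch, head) pair.
        gather = idx[..., None, None].expand(-1, -1, -1, tile, D)
        kt = torch.gather(k_blocks, 2, gather).flatten(2, 3)      # [B, H, topk*tile, D]
        vt = torch.gather(v_blocks, 2, gather).flatten(2, 3)
        attn = F.softmax(qt @ kt.transpose(-1, -2) / D**0.5, dim=-1)
        out.view(B, H, T, tile, D)[:, :, t] = attn @ vt
    return out
```

Because tile selection is computed from pooled queries and keys inside the forward pass, the whole operator stays differentiable and can be trained end-to-end, which is the property the abstract emphasizes over post-hoc sparsity profiling.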
