ChatPaper.aiChatPaper

Difusão de Vídeo Mais Rápida com Atenção Esparsa Treinável

Faster Video Diffusion with Trainable Sparse Attention

May 19, 2025
Autores: Peiyuan Zhang, Haofeng Huang, Yongqi Chen, Will Lin, Zhengzhong Liu, Ion Stoica, Eric P. Xing, Hao Zhang
cs.AI

Resumo

A escalabilidade dos transformadores de difusão de vídeo (DiTs) é limitada por sua atenção 3D quadrática, mesmo que a maior parte da massa de atenção se concentre em um pequeno subconjunto de posições. Transformamos essa observação em VSA, uma atenção esparsa treinável e eficiente em hardware que substitui a atenção completa tanto no treinamento quanto na inferência. No VSA, um estágio leve e grosseiro agrupa tokens em blocos e identifica tokens críticos de alto peso; um estágio fino calcula a atenção em nível de token apenas dentro desses blocos, sujeito a um layout de computação em bloco para garantir eficiência rígida. Isso resulta em um kernel diferenciável único que treina de ponta a ponta, não requer perfilamento pós-treino e mantém 85\% da MFU do FlashAttention3. Realizamos uma ampla varredura de estudos de ablação e experimentos de leis de escalonamento ao pré-treinar DiTs com parâmetros variando de 60M a 1,4B. O VSA atinge um ponto de Pareto que reduz os FLOPS de treinamento em 2,53 vezes sem queda na perda de difusão. A adaptação do modelo de código aberto Wan-2.1 acelera o tempo de atenção em 6 vezes e reduz o tempo de geração de ponta a ponta de 31s para 18s com qualidade comparável. Esses resultados estabelecem a atenção esparsa treinável como uma alternativa prática à atenção completa e um facilitador chave para a escalabilidade adicional de modelos de difusão de vídeo.
English
Scaling video diffusion transformers (DiTs) is limited by their quadratic 3D attention, even though most of the attention mass concentrates on a small subset of positions. We turn this observation into VSA, a trainable, hardware-efficient sparse attention that replaces full attention at both training and inference. In VSA, a lightweight coarse stage pools tokens into tiles and identifies high-weight critical tokens; a fine stage computes token-level attention only inside those tiles subjecting to block computing layout to ensure hard efficiency. This leads to a single differentiable kernel that trains end-to-end, requires no post-hoc profiling, and sustains 85\% of FlashAttention3 MFU. We perform a large sweep of ablation studies and scaling-law experiments by pretraining DiTs from 60M to 1.4B parameters. VSA reaches a Pareto point that cuts training FLOPS by 2.53times with no drop in diffusion loss. Retrofitting the open-source Wan-2.1 model speeds up attention time by 6times and lowers end-to-end generation time from 31s to 18s with comparable quality. These results establish trainable sparse attention as a practical alternative to full attention and a key enabler for further scaling of video diffusion models.
PDF373May 20, 2025