ChatPaper.aiChatPaper

Difusión de Vídeo más Rápida con Atención Dispersa Entrenable

Faster Video Diffusion with Trainable Sparse Attention

May 19, 2025
Autores: Peiyuan Zhang, Haofeng Huang, Yongqi Chen, Will Lin, Zhengzhong Liu, Ion Stoica, Eric P. Xing, Hao Zhang
cs.AI

Resumen

La escalabilidad de los transformadores de difusión de video (DiTs) se ve limitada por su atención 3D cuadrática, a pesar de que la mayor parte de la masa de atención se concentra en un subconjunto pequeño de posiciones. Convertimos esta observación en VSA, una atención dispersa eficiente en hardware y entrenable que reemplaza la atención completa tanto en el entrenamiento como en la inferencia. En VSA, una etapa ligera de agrupación (coarse stage) agrupa los tokens en bloques e identifica los tokens críticos de mayor peso; una etapa detallada (fine stage) calcula la atención a nivel de token solo dentro de esos bloques, siguiendo un diseño de computación por bloques para garantizar eficiencia dura. Esto da lugar a un núcleo diferenciable único que se entrena de extremo a extremo, no requiere perfilado posterior y mantiene el 85\% de la MFU de FlashAttention3. Realizamos un amplio barrido de estudios de ablación y experimentos de leyes de escalabilidad preentrenando DiTs desde 60M hasta 1.4B parámetros. VSA alcanza un punto de Pareto que reduce los FLOPS de entrenamiento en 2.53 veces sin pérdida en la pérdida de difusión. La adaptación del modelo de código abierto Wan-2.1 acelera el tiempo de atención en 6 veces y reduce el tiempo de generación de extremo a extremo de 31s a 18s con calidad comparable. Estos resultados establecen la atención dispersa entrenable como una alternativa práctica a la atención completa y un habilitador clave para seguir escalando los modelos de difusión de video.
English
Scaling video diffusion transformers (DiTs) is limited by their quadratic 3D attention, even though most of the attention mass concentrates on a small subset of positions. We turn this observation into VSA, a trainable, hardware-efficient sparse attention that replaces full attention at both training and inference. In VSA, a lightweight coarse stage pools tokens into tiles and identifies high-weight critical tokens; a fine stage computes token-level attention only inside those tiles subjecting to block computing layout to ensure hard efficiency. This leads to a single differentiable kernel that trains end-to-end, requires no post-hoc profiling, and sustains 85\% of FlashAttention3 MFU. We perform a large sweep of ablation studies and scaling-law experiments by pretraining DiTs from 60M to 1.4B parameters. VSA reaches a Pareto point that cuts training FLOPS by 2.53times with no drop in diffusion loss. Retrofitting the open-source Wan-2.1 model speeds up attention time by 6times and lowers end-to-end generation time from 31s to 18s with comparable quality. These results establish trainable sparse attention as a practical alternative to full attention and a key enabler for further scaling of video diffusion models.

Summary

AI-Generated Summary

PDF291May 20, 2025