Difusión de Vídeo más Rápida con Atención Dispersa Entrenable
Faster Video Diffusion with Trainable Sparse Attention
May 19, 2025
Autores: Peiyuan Zhang, Haofeng Huang, Yongqi Chen, Will Lin, Zhengzhong Liu, Ion Stoica, Eric P. Xing, Hao Zhang
cs.AI
Resumen
La escalabilidad de los transformadores de difusión de video (DiTs) se ve limitada por su atención 3D cuadrática, a pesar de que la mayor parte de la masa de atención se concentra en un subconjunto pequeño de posiciones. Convertimos esta observación en VSA, una atención dispersa eficiente en hardware y entrenable que reemplaza la atención completa tanto en el entrenamiento como en la inferencia. En VSA, una etapa ligera de agrupación (coarse stage) agrupa los tokens en bloques e identifica los tokens críticos de mayor peso; una etapa detallada (fine stage) calcula la atención a nivel de token solo dentro de esos bloques, siguiendo un diseño de computación por bloques para garantizar eficiencia dura. Esto da lugar a un núcleo diferenciable único que se entrena de extremo a extremo, no requiere perfilado posterior y mantiene el 85\% de la MFU de FlashAttention3. Realizamos un amplio barrido de estudios de ablación y experimentos de leyes de escalabilidad preentrenando DiTs desde 60M hasta 1.4B parámetros. VSA alcanza un punto de Pareto que reduce los FLOPS de entrenamiento en 2.53 veces sin pérdida en la pérdida de difusión. La adaptación del modelo de código abierto Wan-2.1 acelera el tiempo de atención en 6 veces y reduce el tiempo de generación de extremo a extremo de 31s a 18s con calidad comparable. Estos resultados establecen la atención dispersa entrenable como una alternativa práctica a la atención completa y un habilitador clave para seguir escalando los modelos de difusión de video.
English
Scaling video diffusion transformers (DiTs) is limited by their quadratic 3D
attention, even though most of the attention mass concentrates on a small
subset of positions. We turn this observation into VSA, a trainable,
hardware-efficient sparse attention that replaces full attention at both
training and inference. In VSA, a lightweight coarse stage pools tokens into
tiles and identifies high-weight critical tokens; a fine stage computes
token-level attention only inside those tiles subjecting to block computing
layout to ensure hard efficiency. This leads to a single differentiable kernel
that trains end-to-end, requires no post-hoc profiling, and sustains 85\% of
FlashAttention3 MFU. We perform a large sweep of ablation studies and
scaling-law experiments by pretraining DiTs from 60M to 1.4B parameters. VSA
reaches a Pareto point that cuts training FLOPS by 2.53times with no drop in
diffusion loss. Retrofitting the open-source Wan-2.1 model speeds up attention
time by 6times and lowers end-to-end generation time from 31s to 18s with
comparable quality. These results establish trainable sparse attention as a
practical alternative to full attention and a key enabler for further scaling
of video diffusion models.Summary
AI-Generated Summary