Diffusion vidéo accélérée grâce à une attention parcimonieuse entraînable
Faster Video Diffusion with Trainable Sparse Attention
May 19, 2025
Auteurs: Peiyuan Zhang, Haofeng Huang, Yongqi Chen, Will Lin, Zhengzhong Liu, Ion Stoica, Eric P. Xing, Hao Zhang
cs.AI
Résumé
La mise à l'échelle des transformateurs de diffusion vidéo (DiTs) est limitée par leur attention 3D quadratique, bien que la majeure partie de la masse attentionnelle se concentre sur un petit sous-ensemble de positions. Nous transformons cette observation en VSA, une attention sparse entraînable et efficace en termes de matériel, qui remplace l'attention complète à la fois pendant l'entraînement et l'inférence. Dans VSA, une étape grossière légère regroupe les tokens en tuiles et identifie les tokens critiques à fort poids ; une étape fine calcule l'attention au niveau des tokens uniquement à l'intérieur de ces tuiles, en respectant une disposition de calcul par blocs pour garantir une efficacité matérielle. Cela conduit à un noyau différentiable unique qui s'entraîne de bout en bout, ne nécessite aucun profilage post-hoc et maintient 85\% de l'MFU de FlashAttention3. Nous effectuons une large série d'études d'ablation et d'expériences de lois d'échelle en pré-entraînant des DiTs de 60M à 1,4B paramètres. VSA atteint un point de Pareto qui réduit les FLOPS d'entraînement par 2,53 fois sans perte de qualité de diffusion. L'adaptation du modèle open-source Wan-2.1 accélère le temps d'attention par 6 fois et réduit le temps de génération de bout en bout de 31s à 18s avec une qualité comparable. Ces résultats établissent l'attention sparse entraînable comme une alternative pratique à l'attention complète et un catalyseur clé pour la mise à l'échelle ultérieure des modèles de diffusion vidéo.
English
Scaling video diffusion transformers (DiTs) is limited by their quadratic 3D
attention, even though most of the attention mass concentrates on a small
subset of positions. We turn this observation into VSA, a trainable,
hardware-efficient sparse attention that replaces full attention at both
training and inference. In VSA, a lightweight coarse stage pools tokens into
tiles and identifies high-weight critical tokens; a fine stage computes
token-level attention only inside those tiles subjecting to block computing
layout to ensure hard efficiency. This leads to a single differentiable kernel
that trains end-to-end, requires no post-hoc profiling, and sustains 85\% of
FlashAttention3 MFU. We perform a large sweep of ablation studies and
scaling-law experiments by pretraining DiTs from 60M to 1.4B parameters. VSA
reaches a Pareto point that cuts training FLOPS by 2.53times with no drop in
diffusion loss. Retrofitting the open-source Wan-2.1 model speeds up attention
time by 6times and lowers end-to-end generation time from 31s to 18s with
comparable quality. These results establish trainable sparse attention as a
practical alternative to full attention and a key enabler for further scaling
of video diffusion models.Summary
AI-Generated Summary