Snellere Video Diffusie met Trainbare Sparse Attention
Faster Video Diffusion with Trainable Sparse Attention
May 19, 2025
Auteurs: Peiyuan Zhang, Haofeng Huang, Yongqi Chen, Will Lin, Zhengzhong Liu, Ion Stoica, Eric P. Xing, Hao Zhang
cs.AI
Samenvatting
Het schalen van video-diffusietransformers (DiTs) wordt beperkt door hun kwadratische 3D-attentie, ook al concentreert het grootste deel van de aandacht zich op een kleine subset van posities. We vertalen deze observatie naar VSA, een trainbare, hardware-efficiënte sparse attention die volledige aandacht vervangt tijdens zowel training als inferentie. In VSA groepeert een lichtgewicht grove fase tokens in tegels en identificeert kritieke tokens met een hoog gewicht; een fijne fase berekent token-level aandacht alleen binnen die tegels, onderworpen aan een blokcomputing-layout om harde efficiëntie te garanderen. Dit resulteert in een enkel differentieerbaar kernel dat end-to-end traint, geen post-hoc profilering vereist en 85\% van de FlashAttention3 MFU behoudt. We voeren een grote reeks ablatiestudies en schaalwetexperimenten uit door DiTs te pretrainen van 60M tot 1,4B parameters. VSA bereikt een Pareto-punt dat de trainings-FLOPS met 2,53 keer vermindert zonder verlies in diffusieverlies. Het retrofitten van het open-source Wan-2.1-model versnelt de aandachtstijd met 6 keer en verlaagt de end-to-end generatietijd van 31s naar 18s met vergelijkbare kwaliteit. Deze resultaten vestigen trainbare sparse attention als een praktisch alternatief voor volledige aandacht en een belangrijke enabler voor verdere schaling van video-diffusiemodellen.
English
Scaling video diffusion transformers (DiTs) is limited by their quadratic 3D
attention, even though most of the attention mass concentrates on a small
subset of positions. We turn this observation into VSA, a trainable,
hardware-efficient sparse attention that replaces full attention at both
training and inference. In VSA, a lightweight coarse stage pools tokens into
tiles and identifies high-weight critical tokens; a fine stage computes
token-level attention only inside those tiles subjecting to block computing
layout to ensure hard efficiency. This leads to a single differentiable kernel
that trains end-to-end, requires no post-hoc profiling, and sustains 85\% of
FlashAttention3 MFU. We perform a large sweep of ablation studies and
scaling-law experiments by pretraining DiTs from 60M to 1.4B parameters. VSA
reaches a Pareto point that cuts training FLOPS by 2.53times with no drop in
diffusion loss. Retrofitting the open-source Wan-2.1 model speeds up attention
time by 6times and lowers end-to-end generation time from 31s to 18s with
comparable quality. These results establish trainable sparse attention as a
practical alternative to full attention and a key enabler for further scaling
of video diffusion models.Summary
AI-Generated Summary