ChatPaper.aiChatPaper

Schnellere Video-Diffusion mit trainierbarer sparser Aufmerksamkeit

Faster Video Diffusion with Trainable Sparse Attention

May 19, 2025
Autoren: Peiyuan Zhang, Haofeng Huang, Yongqi Chen, Will Lin, Zhengzhong Liu, Ion Stoica, Eric P. Xing, Hao Zhang
cs.AI

Zusammenfassung

Die Skalierung von Video-Diffusion-Transformatoren (DiTs) wird durch ihre quadratische 3D-Aufmerksamkeit begrenzt, obwohl sich der Großteil der Aufmerksamkeitsmasse auf eine kleine Teilmenge von Positionen konzentriert. Wir nutzen diese Beobachtung für VSA, eine trainierbare, hardware-effiziente sparse Aufmerksamkeit, die die vollständige Aufmerksamkeit sowohl während des Trainings als auch der Inferenz ersetzt. In VSA aggregiert eine leichte Grobphase Tokens zu Kacheln und identifiziert hochgewichtige kritische Tokens; eine Feinphase berechnet die Token-Level-Aufmerksamkeit nur innerhalb dieser Kacheln, wobei eine Block-Computing-Struktur verwendet wird, um Hard-Effizienz zu gewährleisten. Dies führt zu einem einzigen differenzierbaren Kernel, der end-to-end trainiert, keine nachträgliche Profilerstellung erfordert und 85\% der FlashAttention3-MFU beibehält. Wir führen umfangreiche Ablationsstudien und Skalierungsgesetz-Experimente durch, indem wir DiTs mit 60M bis 1,4B Parametern vortrainieren. VSA erreicht einen Pareto-Punkt, der die Trainings-FLOPS um das 2,53-fache reduziert, ohne den Diffusionsverlust zu erhöhen. Die Nachrüstung des Open-Source-Modells Wan-2.1 beschleunigt die Aufmerksamkeitszeit um das 6-fache und verkürzt die end-to-end-Generierungszeit von 31s auf 18s bei vergleichbarer Qualität. Diese Ergebnisse etablieren trainierbare sparse Aufmerksamkeit als praktische Alternative zur vollständigen Aufmerksamkeit und als Schlüsseltechnologie für die weitere Skalierung von Video-Diffusionsmodellen.
English
Scaling video diffusion transformers (DiTs) is limited by their quadratic 3D attention, even though most of the attention mass concentrates on a small subset of positions. We turn this observation into VSA, a trainable, hardware-efficient sparse attention that replaces full attention at both training and inference. In VSA, a lightweight coarse stage pools tokens into tiles and identifies high-weight critical tokens; a fine stage computes token-level attention only inside those tiles subjecting to block computing layout to ensure hard efficiency. This leads to a single differentiable kernel that trains end-to-end, requires no post-hoc profiling, and sustains 85\% of FlashAttention3 MFU. We perform a large sweep of ablation studies and scaling-law experiments by pretraining DiTs from 60M to 1.4B parameters. VSA reaches a Pareto point that cuts training FLOPS by 2.53times with no drop in diffusion loss. Retrofitting the open-source Wan-2.1 model speeds up attention time by 6times and lowers end-to-end generation time from 31s to 18s with comparable quality. These results establish trainable sparse attention as a practical alternative to full attention and a key enabler for further scaling of video diffusion models.

Summary

AI-Generated Summary

PDF291May 20, 2025