MPDiT : Architecture Transformer Globale-Locale Multi-Patch pour un Modèle de Flow Matching et de Diffusion Efficace
MPDiT: Multi-Patch Global-to-Local Transformer Architecture For Efficient Flow Matching and Diffusion Model
March 27, 2026
Auteurs: Quan Dao, Dimitris Metaxas
cs.AI
Résumé
Les architectures de type Transformer, particulièrement les Transformers de Diffusion (DiTs), sont largement utilisées dans les modèles de diffusion et d'appariement de flux en raison de leurs performances supérieures comparées aux UNets convolutionnels. Cependant, la conception isotrope des DiTs traite le même nombre de tokens patchifiés dans chaque bloc, entraînant un calcul relativement lourd pendant l'entraînement. Dans ce travail, nous introduisons une conception de transformer multi-patch où les premiers blocs opèrent sur des patches plus larges pour capturer le contexte global grossier, tandis que les blocs suivants utilisent des patches plus petits pour affiner les détails locaux. Cette conception hiérarchique permet de réduire le coût computationnel jusqu'à 50\% en GFLOPs tout en atteignant de bonnes performances génératives. De plus, nous proposons également des conceptions améliorées pour les embeddings temporels et de classe qui accélèrent la convergence de l'entraînement. Des expériences approfondies sur le jeu de données ImageNet démontrent l'efficacité de nos choix architecturaux. Le code est disponible à l'adresse https://github.com/quandao10/MPDiT.
English
Transformer architectures, particularly Diffusion Transformers (DiTs), have become widely used in diffusion and flow-matching models due to their strong performance compared to convolutional UNets. However, the isotropic design of DiTs processes the same number of patchified tokens in every block, leading to relatively heavy computation during training process. In this work, we introduce a multi-patch transformer design in which early blocks operate on larger patches to capture coarse global context, while later blocks use smaller patches to refine local details. This hierarchical design could reduces computational cost by up to 50\% in GFLOPs while achieving good generative performance. In addition, we also propose improved designs for time and class embeddings that accelerate training convergence. Extensive experiments on the ImageNet dataset demonstrate the effectiveness of our architectural choices. Code is released at https://github.com/quandao10/MPDiT