Remasquage des modèles de diffusion discrets avec mise à l'échelle au moment de l'inférence
Remasking Discrete Diffusion Models with Inference-Time Scaling
March 1, 2025
Auteurs: Guanghan Wang, Yair Schiff, Subham Sekhar Sahoo, Volodymyr Kuleshov
cs.AI
Résumé
Une partie du succès des modèles de diffusion provient de leur capacité à effectuer un raffinement itératif, c'est-à-dire à corriger de manière répétée les sorties pendant la génération. Cependant, la diffusion discrète masquée moderne ne possède pas cette capacité : lorsqu'un token est généré, il ne peut plus être mis à jour, même s'il introduit une erreur. Ici, nous abordons cette limitation en introduisant l'échantillonneur ReMDM (Remasking Diffusion Model), une méthode qui peut être appliquée de manière rigoureuse à des modèles de diffusion masquée pré-entraînés et qui est dérivée d'un modèle de diffusion discrète avec un processus de retour personnalisé de remasquage. Plus intéressant encore, ReMDM confère à la diffusion discrète une forme de mise à l'échelle du calcul au moment de l'inférence. En augmentant le nombre d'étapes d'échantillonnage, ReMDM génère des sorties en langage naturel qui approchent la qualité des modèles autorégressifs, tandis que lorsque le budget de calcul est limité, ReMDM maintient mieux la qualité. ReMDM améliore également la qualité des échantillons des modèles de diffusion masquée pour les images discrétisées, et dans des domaines scientifiques tels que la conception de molécules, ReMDM facilite le guidage par diffusion et repousse la frontière de Pareto de la contrôlabilité par rapport au masquage classique et à la diffusion de bruit uniforme. Nous fournissons le code ainsi qu'un article de blog sur la page du projet : https://remdm.github.io.
English
Part of the success of diffusion models stems from their ability to perform
iterative refinement, i.e., repeatedly correcting outputs during generation.
However, modern masked discrete diffusion lacks this capability: when a token
is generated, it cannot be updated again, even when it introduces an error.
Here, we address this limitation by introducing the remasking diffusion model
(ReMDM) sampler, a method that can be applied to pretrained masked diffusion
models in a principled way and that is derived from a discrete diffusion model
with a custom remasking backward process. Most interestingly, ReMDM endows
discrete diffusion with a form of inference-time compute scaling. By increasing
the number of sampling steps, ReMDM generates natural language outputs that
approach the quality of autoregressive models, whereas when the computation
budget is limited, ReMDM better maintains quality. ReMDM also improves sample
quality of masked diffusion models for discretized images, and in scientific
domains such as molecule design, ReMDM facilitates diffusion guidance and
pushes the Pareto frontier of controllability relative to classical masking and
uniform noise diffusion. We provide the code along with a blog post on the
project page: https://remdm.github.io.Summary
AI-Generated Summary