Modèle de valeur fusionné pour l'alignement de diffusion
Stitched Value Model for Diffusion Alignment
May 19, 2026
Auteurs: Hyojun Go, Hyungjin Chung, Prune Truong, Goutam Bhat, Li Mi, Zhaochong An, Zixiang Zhao, Dominik Narnhofer, Serge Belongie, Federico Tombari, Konrad Schindler
cs.AI
Résumé
Pour une utilisation pratique, les modèles génératifs basés sur la diffusion ou le flux doivent être alignés sur des récompenses spécifiques à la tâche, telles que la fidélité à la consigne ou la préférence esthétique. Cet alignement est difficile car la récompense est définie pour des images de sortie propres, mais la procédure d’alignement nécessite des estimations de la fonction de valeur sur des latents intermédiaires bruités. Les méthodes existantes recourent à des approximations de type Tweedie ou Monte Carlo, faisant un compromis entre le biais de l’estimateur et le coût de calcul : les estimations de Tweedie sont efficaces mais biaisées, tandis que celles de Monte Carlo sont plus précises mais nécessitent des déploiements coûteux. Une alternative naturelle serait une fonction de valeur apprise, mais il reste une question ouverte de savoir comment entraîner efficacement un modèle de valeur robuste et général, spécifiquement pour les latents bruités. Nous proposons ici StitchVM, un cadre d’assemblage de modèles qui transfère efficacement les modèles de récompense pré-entraînés pour des images propres au régime des latents bruités. StitchVM part d’un modèle de récompense existant, tronqué dans l’espace pixel, et y attache un backbone de diffusion figé comme tête. Du modèle dans l’espace pixel, l’hybride résultant conserve une capacité de récompense robuste et soigneusement pré-entraînée ; du backbone de diffusion, il hérite de sa capacité native à traiter les latents bruités. La procédure d’assemblage est exceptionnellement légère : par exemple, assembler et affiner CLIP ViT-L et SD 3.5 Medium ne prend que 10 heures GPU. En élevant des modèles de récompense puissants de l’espace pixel à l’espace latent, StitchVM ouvre un nouveau style d’alignement par diffusion : au lieu d’une approximation approximative mais coûteuse par échantillon de la fonction de valeur, la fonction correcte pour les latents réels et bruités est construite une fois puis amortie sur de nombreux échantillons et itérations. Nous montrons que cette approche apporte des améliorations dans un large éventail de méthodes de guidage et de post-entraînement en aval : DPS devient 3,2 fois plus rapide tout en réduisant de moitié la mémoire GPU de pointe, et DiffusionNFT devient 2,3 fois plus rapide.
English
For practical use, diffusion- or flow-based generative models must be aligned with task-specific rewards, such as prompt fidelity or aesthetic preference. That alignment is challenging because the reward is defined for clean output images, but the alignment procedure requires value function estimates at noisy intermediate latents. Existing methods resort to Tweedie-style or Monte Carlo approximations, trading off estimator bias against computational cost: Tweedie estimates are efficient but biased, while Monte Carlo estimates are more accurate but require expensive rollouts. A natural alternative would be a learned value function, but it remains an open question how to effectively train a strong and general value model specifically for noisy latents. Here, we propose StitchVM, a model stitching framework that efficiently transfers reward models pretrained for clean images to the noisy latent regime. StitchVM starts from an existing, truncated pixel-space reward model and attaches a frozen diffusion backbone to it as its head. From the pixel-space model, the resulting hybrid retains a carefully pretrained, robust reward capability; from the diffusion backbone, it inherits its native ability to handle noisy latents. The stitching procedure is exceptionally lightweight, e.g., stitching and finetuning CLIP ViT-L and SD 3.5 Medium takes only 10 GPU-hours. By lifting powerful pixel-space reward models to latent space, StitchVM opens up a new style of diffusion alignment: instead of rough, yet costly per-sample approximation of the value function, the correct function for the actual, noisy latents is constructed once and then amortized over many samples and iterations. We show that this approach yields improvements across a broad range of downstream steering and post-training methods: DPS becomes 3.2times faster while halving peak GPU memory, and DiffusionNFT becomes 2.3times faster.