Stitched-Wertmodell für Diffusionsalignment
Stitched Value Model for Diffusion Alignment
May 19, 2026
Autoren: Hyojun Go, Hyungjin Chung, Prune Truong, Goutam Bhat, Li Mi, Zhaochong An, Zixiang Zhao, Dominik Narnhofer, Serge Belongie, Federico Tombari, Konrad Schindler
cs.AI
Zusammenfassung
Für den praktischen Einsatz müssen Diffusions- oder flussbasierte generative Modelle an aufgabenspezifische Belohnungen wie Prompt-Treue oder ästhetische Präferenz angepasst werden. Diese Anpassung ist herausfordernd, da die Belohnung für saubere Ausgabebilder definiert ist, das Anpassungsverfahren jedoch Schätzungen der Wertefunktion auf verrauschten, intermediären Latents erfordert. Bestehende Methoden greifen auf Tweedie-artige oder Monte-Carlo-Approximationen zurück und tauschen dabei Schätzerverzerrung gegen Rechenaufwand ein: Tweedie-Schätzungen sind effizient, aber verzerrt, während Monte-Carlo-Schätzungen genauer sind, jedoch aufwändige Rollouts erfordern. Eine natürliche Alternative wäre eine gelernte Wertefunktion, aber es bleibt eine offene Frage, wie man ein starkes und allgemeines Wertemodell speziell für verrauschte Latents effektiv trainieren kann.
Hier schlagen wir StitchVM vor, ein Modell-Stitching-Framework, das für saubere Bilder vortrainierte Belohnungsmodelle effizient in den Bereich verrauschter Latents überführt. StitchVM geht von einem bestehenden, abgeschnittenen Pixelraum-Belohnungsmodell aus und fügt einen eingefrorenen Diffusions-Backbone als dessen Kopf an. Vom Pixelraum-Modell behält der resultierende Hybrid eine sorgfältig vortrainierte, robuste Belohnungsfähigkeit; vom Diffusions-Backbone erbt er dessen natürliche Fähigkeit, mit verrauschten Latents umzugehen. Das Stitching-Verfahren ist außergewöhnlich leichtgewichtig; beispielsweise dauert das Stitching und Feintuning von CLIP ViT-L und SD 3.5 Medium nur 10 GPU-Stunden.
Durch die Übertragung leistungsfähiger Pixelraum-Belohnungsmodelle in den Latent-Raum eröffnet StitchVM einen neuen Stil der Diffusionsanpassung: Anstatt einer groben, aber aufwändigen stichprobenweisen Approximation der Wertefunktion wird die korrekte Funktion für die tatsächlichen, verrauschten Latents einmal konstruiert und dann über viele Stichproben und Iterationen amortisiert. Wir zeigen, dass dieser Ansatz bei einer breiten Palette nachgelagerter Steuerungs- und Nachtrainingsmethoden Verbesserungen bringt: DPS wird 3,2-mal schneller bei halbierter GPU-Speicherspitze, und DiffusionNFT wird 2,3-mal schneller.
English
For practical use, diffusion- or flow-based generative models must be aligned with task-specific rewards, such as prompt fidelity or aesthetic preference. That alignment is challenging because the reward is defined for clean output images, but the alignment procedure requires value function estimates at noisy intermediate latents. Existing methods resort to Tweedie-style or Monte Carlo approximations, trading off estimator bias against computational cost: Tweedie estimates are efficient but biased, while Monte Carlo estimates are more accurate but require expensive rollouts. A natural alternative would be a learned value function, but it remains an open question how to effectively train a strong and general value model specifically for noisy latents. Here, we propose StitchVM, a model stitching framework that efficiently transfers reward models pretrained for clean images to the noisy latent regime. StitchVM starts from an existing, truncated pixel-space reward model and attaches a frozen diffusion backbone to it as its head. From the pixel-space model, the resulting hybrid retains a carefully pretrained, robust reward capability; from the diffusion backbone, it inherits its native ability to handle noisy latents. The stitching procedure is exceptionally lightweight, e.g., stitching and finetuning CLIP ViT-L and SD 3.5 Medium takes only 10 GPU-hours. By lifting powerful pixel-space reward models to latent space, StitchVM opens up a new style of diffusion alignment: instead of rough, yet costly per-sample approximation of the value function, the correct function for the actual, noisy latents is constructed once and then amortized over many samples and iterations. We show that this approach yields improvements across a broad range of downstream steering and post-training methods: DPS becomes 3.2times faster while halving peak GPU memory, and DiffusionNFT becomes 2.3times faster.