拼接價值模型用於擴散對齊
Stitched Value Model for Diffusion Alignment
May 19, 2026
作者: Hyojun Go, Hyungjin Chung, Prune Truong, Goutam Bhat, Li Mi, Zhaochong An, Zixiang Zhao, Dominik Narnhofer, Serge Belongie, Federico Tombari, Konrad Schindler
cs.AI
摘要
在实际应用中,基于扩散或流的生成模型必须与任务特定的奖励信号(如提示保真度或审美偏好)对齐。这一对齐过程颇具挑战性,因为奖励是针对干净输出图像定义的,而对齐过程需要在带有噪声的中间隐变量处进行价值函数估计。现有方法采用 Tweedie 式或蒙特卡洛近似,在估计偏差与计算代价之间进行权衡:Tweedie 估计效率高但有偏,而蒙特卡洛估计更精确但需要昂贵的轨迹展开。一种自然的替代方案是使用学习得到的价值函数,但如何针对噪声隐变量高效训练一个强大且通用的价值模型仍是一个开放问题。本文提出 StitchVM,一种模型拼接框架,能够将针对干净图像预训练的奖励模型高效迁移至噪声隐变量场景。StitchVM 从一个现成的、截断的像素空间奖励模型出发,将一个冻结的扩散骨干网络作为其头部附加其上。所得混合模型从像素空间模型继承了精心预训练的、稳健的奖励能力;从扩散骨干网络继承了其处理噪声隐变量的原生能力。拼接过程极为轻量:例如,将 CLIP ViT-L 与 SD 3.5 Medium 拼接并微调仅需 10 GPU 小时。通过将强大的像素空间奖励模型提升至隐空间,StitchVM 开辟了一种全新的扩散对齐方式:不再对每个样本进行粗略而昂贵的价值函数近似,而是为实际的噪声隐变量一次性构造正确的函数,并在众多样本和迭代次数中摊销计算成本。我们展示,该方法在多种下游引导和后训练方法中均能带来改进:DPS 速度提升 3.2 倍,同时峰值 GPU 内存减半;DiffusionNFT 速度提升 2.3 倍。
English
For practical use, diffusion- or flow-based generative models must be aligned with task-specific rewards, such as prompt fidelity or aesthetic preference. That alignment is challenging because the reward is defined for clean output images, but the alignment procedure requires value function estimates at noisy intermediate latents. Existing methods resort to Tweedie-style or Monte Carlo approximations, trading off estimator bias against computational cost: Tweedie estimates are efficient but biased, while Monte Carlo estimates are more accurate but require expensive rollouts. A natural alternative would be a learned value function, but it remains an open question how to effectively train a strong and general value model specifically for noisy latents. Here, we propose StitchVM, a model stitching framework that efficiently transfers reward models pretrained for clean images to the noisy latent regime. StitchVM starts from an existing, truncated pixel-space reward model and attaches a frozen diffusion backbone to it as its head. From the pixel-space model, the resulting hybrid retains a carefully pretrained, robust reward capability; from the diffusion backbone, it inherits its native ability to handle noisy latents. The stitching procedure is exceptionally lightweight, e.g., stitching and finetuning CLIP ViT-L and SD 3.5 Medium takes only 10 GPU-hours. By lifting powerful pixel-space reward models to latent space, StitchVM opens up a new style of diffusion alignment: instead of rough, yet costly per-sample approximation of the value function, the correct function for the actual, noisy latents is constructed once and then amortized over many samples and iterations. We show that this approach yields improvements across a broad range of downstream steering and post-training methods: DPS becomes 3.2times faster while halving peak GPU memory, and DiffusionNFT becomes 2.3times faster.