PRDP : Prédiction de la Différence de Récompense Proximale pour le Réglage Fin à Grande Échelle des Récompenses dans les Modèles de Diffusion
PRDP: Proximal Reward Difference Prediction for Large-Scale Reward Finetuning of Diffusion Models
February 13, 2024
Auteurs: Fei Deng, Qifei Wang, Wei Wei, Matthias Grundmann, Tingbo Hou
cs.AI
Résumé
Le réglage par récompense s'est imposé comme une approche prometteuse pour aligner les modèles de base avec les objectifs en aval. Des succès remarquables ont été obtenus dans le domaine du langage en utilisant l'apprentissage par renforcement (RL) pour maximiser les récompenses reflétant les préférences humaines. Cependant, dans le domaine de la vision, les méthodes existantes de réglage par récompense basées sur le RL sont limitées par leur instabilité lors de l'entraînement à grande échelle, les rendant incapables de généraliser à des prompts complexes et inédits. Dans cet article, nous proposons la Prédiction de Différence de Récompense Proximale (PRDP), permettant pour la première fois un réglage par récompense stable en boîte noire pour les modèles de diffusion sur des ensembles de données de prompts à grande échelle contenant plus de 100 000 prompts. Notre innovation clé est l'objectif de Prédiction de Différence de Récompense (RDP) qui possède la même solution optimale que l'objectif du RL tout en bénéficiant d'une meilleure stabilité d'entraînement. Concrètement, l'objectif RDP est un objectif de régression supervisée qui consiste à demander au modèle de diffusion de prédire la différence de récompense entre des paires d'images générées à partir de leurs trajectoires de débruitage. Nous prouvons théoriquement que le modèle de diffusion qui obtient une prédiction parfaite de la différence de récompense est exactement le maximiseur de l'objectif du RL. Nous développons en outre un algorithme en ligne avec des mises à jour proximales pour optimiser de manière stable l'objectif RDP. Dans les expériences, nous démontrons que PRDP peut égaler la capacité de maximisation des récompenses des méthodes bien établies basées sur le RL lors d'un entraînement à petite échelle. De plus, grâce à un entraînement à grande échelle sur des prompts textuels issus du Human Preference Dataset v2 et du dataset Pick-a-Pic v1, PRDP atteint une qualité de génération supérieure sur un ensemble diversifié de prompts complexes et inédits, alors que les méthodes basées sur le RL échouent complètement.
English
Reward finetuning has emerged as a promising approach to aligning foundation
models with downstream objectives. Remarkable success has been achieved in the
language domain by using reinforcement learning (RL) to maximize rewards that
reflect human preference. However, in the vision domain, existing RL-based
reward finetuning methods are limited by their instability in large-scale
training, rendering them incapable of generalizing to complex, unseen prompts.
In this paper, we propose Proximal Reward Difference Prediction (PRDP),
enabling stable black-box reward finetuning for diffusion models for the first
time on large-scale prompt datasets with over 100K prompts. Our key innovation
is the Reward Difference Prediction (RDP) objective that has the same optimal
solution as the RL objective while enjoying better training stability.
Specifically, the RDP objective is a supervised regression objective that tasks
the diffusion model with predicting the reward difference of generated image
pairs from their denoising trajectories. We theoretically prove that the
diffusion model that obtains perfect reward difference prediction is exactly
the maximizer of the RL objective. We further develop an online algorithm with
proximal updates to stably optimize the RDP objective. In experiments, we
demonstrate that PRDP can match the reward maximization ability of
well-established RL-based methods in small-scale training. Furthermore, through
large-scale training on text prompts from the Human Preference Dataset v2 and
the Pick-a-Pic v1 dataset, PRDP achieves superior generation quality on a
diverse set of complex, unseen prompts whereas RL-based methods completely
fail.Summary
AI-Generated Summary