自我改进的健壮偏好优化
Self-Improving Robust Preference Optimization
June 3, 2024
作者: Eugene Choi, Arash Ahmadian, Matthieu Geist, Oilvier Pietquin, Mohammad Gheshlaghi Azar
cs.AI
摘要
在线和离线的RLHF方法,如PPO和DPO,在将人工智能与人类偏好对齐方面取得了极大成功。尽管它们取得了成功,但现有方法存在一个根本问题,即它们的最优解高度依赖任务(即对分布外(OOD)任务不具有鲁棒性)。在这里,我们通过提出自我改进鲁棒偏好优化SRPO来解决这一挑战,这是一个实用且在数学上合理的离线RLHF框架,完全能够适应任务的变化。SRPO的关键思想是将从人类偏好中学习的问题视为一个自我改进的过程,可以用一个旨在通过对抗方式联合优化自我改进策略和生成策略的极小极大目标来进行数学表达。这种优化问题的解决方案独立于训练任务,因此对其变化具有鲁棒性。然后,我们展示了这一目标可以重新表达为一种非对抗性的离线损失形式,可以在规模上使用标准监督优化技术进行优化,而无需奖励模型和在线推断。我们展示了SRPO在AI胜率(WR)对人类(GOLD)完成情况的效果。特别是,在OOD XSUM数据集上评估SRPO后,经过5次自我修订后,其胜率达到90%,比著名的DPO明显高出15%。
English
Both online and offline RLHF methods such as PPO and DPO have been extremely
successful in aligning AI with human preferences. Despite their success, the
existing methods suffer from a fundamental problem that their optimal solution
is highly task-dependent (i.e., not robust to out-of-distribution (OOD) tasks).
Here we address this challenge by proposing Self-Improving Robust Preference
Optimization SRPO, a practical and mathematically principled offline RLHF
framework that is completely robust to the changes in the task. The key idea of
SRPO is to cast the problem of learning from human preferences as a
self-improvement process, which can be mathematically expressed in terms of a
min-max objective that aims at joint optimization of self-improvement policy
and the generative policy in an adversarial fashion. The solution for this
optimization problem is independent of the training task and thus it is robust
to its changes. We then show that this objective can be re-expressed in the
form of a non-adversarial offline loss which can be optimized using standard
supervised optimization techniques at scale without any need for reward model
and online inference. We show the effectiveness of SRPO in terms of AI Win-Rate
(WR) against human (GOLD) completions. In particular, when SRPO is evaluated on
the OOD XSUM dataset, it outperforms the celebrated DPO by a clear margin of
15% after 5 self-revisions, achieving WR of 90%.