ChatPaper.aiChatPaper

迭代推理偏好優化

Iterative Reasoning Preference Optimization

April 30, 2024
作者: Richard Yuanzhe Pang, Weizhe Yuan, Kyunghyun Cho, He He, Sainbayar Sukhbaatar, Jason Weston
cs.AI

摘要

最近,已經顯示迭代式偏好優化方法在一般指導調整任務中表現良好,但通常對推理任務幾乎沒有改進(Yuan等,2024年,Chen等,2024年)。在這項工作中,我們開發了一種迭代方法,通過優化競爭生成的“思維鏈”(CoT)候選者之間的偏好,來優化導致正確答案的勝利與失敗推理步驟。我們使用修改後的DPO損失(Rafailov等,2023年)進行訓練,並加入了一個額外的負對數概似項,我們發現這是至關重要的。我們展示了這種方案的重複迭代過程中推理能力的改善。儘管僅依賴於訓練集中的示例,我們的方法使得在GSM8K上Llama-2-70B-Chat的準確率從55.6%提高到81.6%(在32個樣本中以多數投票達到88.7%),在MATH上從12.5%提高到20.8%,在ARC-Challenge上從77.8%提高到86.7%,這超越了其他不依賴額外來源數據集的基於Llama-2的模型。
English
Iterative preference optimization methods have recently been shown to perform well for general instruction tuning tasks, but typically make little improvement on reasoning tasks (Yuan et al., 2024, Chen et al., 2024). In this work we develop an iterative approach that optimizes the preference between competing generated Chain-of-Thought (CoT) candidates by optimizing for winning vs. losing reasoning steps that lead to the correct answer. We train using a modified DPO loss (Rafailov et al., 2023) with an additional negative log-likelihood term, which we find to be crucial. We show reasoning improves across repeated iterations of this scheme. While only relying on examples in the training set, our approach results in increasing accuracy for Llama-2-70B-Chat from 55.6% to 81.6% on GSM8K (and 88.7% with majority voting out of 32 samples), from 12.5% to 20.8% on MATH, and from 77.8% to 86.7% on ARC-Challenge, which outperforms other Llama-2-based models not relying on additionally sourced datasets.

Summary

AI-Generated Summary

PDF506December 8, 2024