ChatPaper.aiChatPaper

反復的推論選好最適化

Iterative Reasoning Preference Optimization

April 30, 2024
著者: Richard Yuanzhe Pang, Weizhe Yuan, Kyunghyun Cho, He He, Sainbayar Sukhbaatar, Jason Weston
cs.AI

要旨

反復的な選好最適化手法は、一般的な指示チューニングタスクにおいて良好な性能を示すことが最近明らかになりましたが、推論タスクではほとんど改善が見られないことが一般的です(Yuan et al., 2024; Chen et al., 2024)。本研究では、正解に至る勝ち負けの推論ステップを最適化することで、競合するChain-of-Thought(CoT)候補間の選好を最適化する反復的アプローチを開発します。我々は、修正されたDPO損失(Rafailov et al., 2023)に追加の負の対数尤度項を用いて学習を行い、これが重要であることを確認しました。このスキームを繰り返し適用することで、推論能力が向上することを示します。訓練セットの例のみに依存しながら、我々のアプローチにより、Llama-2-70B-Chatの精度はGSM8Kで55.6%から81.6%(32サンプルの多数決では88.7%)、MATHで12.5%から20.8%、ARC-Challengeで77.8%から86.7%に向上し、追加のデータセットに依存しない他のLlama-2ベースのモデルを上回りました。
English
Iterative preference optimization methods have recently been shown to perform well for general instruction tuning tasks, but typically make little improvement on reasoning tasks (Yuan et al., 2024, Chen et al., 2024). In this work we develop an iterative approach that optimizes the preference between competing generated Chain-of-Thought (CoT) candidates by optimizing for winning vs. losing reasoning steps that lead to the correct answer. We train using a modified DPO loss (Rafailov et al., 2023) with an additional negative log-likelihood term, which we find to be crucial. We show reasoning improves across repeated iterations of this scheme. While only relying on examples in the training set, our approach results in increasing accuracy for Llama-2-70B-Chat from 55.6% to 81.6% on GSM8K (and 88.7% with majority voting out of 32 samples), from 12.5% to 20.8% on MATH, and from 77.8% to 86.7% on ARC-Challenge, which outperforms other Llama-2-based models not relying on additionally sourced datasets.

Summary

AI-Generated Summary

PDF506December 8, 2024