ChatPaper.aiChatPaper

Iteratieve Redeneervoorkeuroptimalisatie

Iterative Reasoning Preference Optimization

April 30, 2024
Auteurs: Richard Yuanzhe Pang, Weizhe Yuan, Kyunghyun Cho, He He, Sainbayar Sukhbaatar, Jason Weston
cs.AI

Samenvatting

Iteratieve voorkeursoptimalisatiemethoden hebben recentelijk goede resultaten laten zien voor algemene instructieafstemmingstaken, maar leveren doorgaans weinig verbetering op voor redeneertaken (Yuan et al., 2024, Chen et al., 2024). In dit werk ontwikkelen we een iteratieve aanpak die de voorkeur optimaliseert tussen concurrerende gegenereerde Chain-of-Thought (CoT)-kandidaten door te optimaliseren voor winnende versus verliezende redeneerstappen die leiden tot het juiste antwoord. We trainen met een aangepast DPO-verlies (Rafailov et al., 2023) met een aanvullende negatieve log-waarschijnlijkheidsterm, die we cruciaal vinden. We laten zien dat het redeneren verbetert over herhaalde iteraties van dit schema. Hoewel we alleen vertrouwen op voorbeelden in de trainingsset, resulteert onze aanpak in een toenemende nauwkeurigheid voor Llama-2-70B-Chat van 55,6% naar 81,6% op GSM8K (en 88,7% met meerderheidsstemming uit 32 steekproeven), van 12,5% naar 20,8% op MATH, en van 77,8% naar 86,7% op ARC-Challenge, wat andere Llama-2-gebaseerde modellen overtreft die niet vertrouwen op aanvullende datasets.
English
Iterative preference optimization methods have recently been shown to perform well for general instruction tuning tasks, but typically make little improvement on reasoning tasks (Yuan et al., 2024, Chen et al., 2024). In this work we develop an iterative approach that optimizes the preference between competing generated Chain-of-Thought (CoT) candidates by optimizing for winning vs. losing reasoning steps that lead to the correct answer. We train using a modified DPO loss (Rafailov et al., 2023) with an additional negative log-likelihood term, which we find to be crucial. We show reasoning improves across repeated iterations of this scheme. While only relying on examples in the training set, our approach results in increasing accuracy for Llama-2-70B-Chat from 55.6% to 81.6% on GSM8K (and 88.7% with majority voting out of 32 samples), from 12.5% to 20.8% on MATH, and from 77.8% to 86.7% on ARC-Challenge, which outperforms other Llama-2-based models not relying on additionally sourced datasets.
PDF496February 8, 2026