Ottimizzazione Iterativa delle Preferenze di Ragionamento
Iterative Reasoning Preference Optimization
April 30, 2024
Autori: Richard Yuanzhe Pang, Weizhe Yuan, Kyunghyun Cho, He He, Sainbayar Sukhbaatar, Jason Weston
cs.AI
Abstract
I metodi di ottimizzazione iterativa delle preferenze hanno recentemente dimostrato di funzionare bene per compiti generali di tuning delle istruzioni, ma tipicamente apportano pochi miglioramenti nei compiti di ragionamento (Yuan et al., 2024, Chen et al., 2024). In questo lavoro sviluppiamo un approccio iterativo che ottimizza la preferenza tra candidati generati di tipo Chain-of-Thought (CoT) ottimizzando per i passaggi di ragionamento vincenti rispetto a quelli perdenti che portano alla risposta corretta. Addestriamo utilizzando una funzione di perdita DPO modificata (Rafailov et al., 2023) con un termine aggiuntivo di log-verosimiglianza negativa, che riteniamo cruciale. Mostriamo che il ragionamento migliora attraverso iterazioni ripetute di questo schema. Pur basandoci esclusivamente sugli esempi nel set di addestramento, il nostro approccio porta a un aumento dell'accuratezza per Llama-2-70B-Chat dal 55,6% all'81,6% su GSM8K (e all'88,7% con il voto a maggioranza su 32 campioni), dal 12,5% al 20,8% su MATH e dal 77,8% all'86,7% su ARC-Challenge, superando altri modelli basati su Llama-2 che non si avvalgono di dataset aggiuntivi.
English
Iterative preference optimization methods have recently been shown to perform
well for general instruction tuning tasks, but typically make little
improvement on reasoning tasks (Yuan et al., 2024, Chen et al., 2024). In this
work we develop an iterative approach that optimizes the preference between
competing generated Chain-of-Thought (CoT) candidates by optimizing for winning
vs. losing reasoning steps that lead to the correct answer. We train using a
modified DPO loss (Rafailov et al., 2023) with an additional negative
log-likelihood term, which we find to be crucial. We show reasoning improves
across repeated iterations of this scheme. While only relying on examples in
the training set, our approach results in increasing accuracy for
Llama-2-70B-Chat from 55.6% to 81.6% on GSM8K (and 88.7% with majority voting
out of 32 samples), from 12.5% to 20.8% on MATH, and from 77.8% to 86.7% on
ARC-Challenge, which outperforms other Llama-2-based models not relying on
additionally sourced datasets.