ChatPaper.aiChatPaper

Otimização Iterativa de Preferências de Raciocínio

Iterative Reasoning Preference Optimization

April 30, 2024
Autores: Richard Yuanzhe Pang, Weizhe Yuan, Kyunghyun Cho, He He, Sainbayar Sukhbaatar, Jason Weston
cs.AI

Resumo

Métodos iterativos de otimização de preferências têm demonstrado recentemente um bom desempenho em tarefas gerais de ajuste de instruções, mas geralmente trazem pouca melhoria em tarefas de raciocínio (Yuan et al., 2024, Chen et al., 2024). Neste trabalho, desenvolvemos uma abordagem iterativa que otimiza a preferência entre candidatos gerados de Cadeia de Pensamento (CoT) concorrentes, otimizando para etapas de raciocínio vencedoras versus perdedoras que levam à resposta correta. Treinamos usando uma função de perda DPO modificada (Rafailov et al., 2023) com um termo adicional de log-verossimilhança negativa, que consideramos crucial. Mostramos que o raciocínio melhora ao longo de iterações repetidas desse esquema. Apesar de depender apenas de exemplos do conjunto de treinamento, nossa abordagem resulta em um aumento de precisão para o Llama-2-70B-Chat de 55,6% para 81,6% no GSM8K (e 88,7% com votação majoritária de 32 amostras), de 12,5% para 20,8% no MATH e de 77,8% para 86,7% no ARC-Challenge, superando outros modelos baseados no Llama-2 que não utilizam conjuntos de dados adicionais.
English
Iterative preference optimization methods have recently been shown to perform well for general instruction tuning tasks, but typically make little improvement on reasoning tasks (Yuan et al., 2024, Chen et al., 2024). In this work we develop an iterative approach that optimizes the preference between competing generated Chain-of-Thought (CoT) candidates by optimizing for winning vs. losing reasoning steps that lead to the correct answer. We train using a modified DPO loss (Rafailov et al., 2023) with an additional negative log-likelihood term, which we find to be crucial. We show reasoning improves across repeated iterations of this scheme. While only relying on examples in the training set, our approach results in increasing accuracy for Llama-2-70B-Chat from 55.6% to 81.6% on GSM8K (and 88.7% with majority voting out of 32 samples), from 12.5% to 20.8% on MATH, and from 77.8% to 86.7% on ARC-Challenge, which outperforms other Llama-2-based models not relying on additionally sourced datasets.
PDF496December 8, 2024