ChatPaper.aiChatPaper

Optimisation Itérative des Préférences de Raisonnement

Iterative Reasoning Preference Optimization

April 30, 2024
Auteurs: Richard Yuanzhe Pang, Weizhe Yuan, Kyunghyun Cho, He He, Sainbayar Sukhbaatar, Jason Weston
cs.AI

Résumé

Les méthodes d'optimisation itérative des préférences ont récemment démontré de bonnes performances pour les tâches générales de réglage d'instructions, mais elles apportent généralement peu d'amélioration pour les tâches de raisonnement (Yuan et al., 2024, Chen et al., 2024). Dans ce travail, nous développons une approche itérative qui optimise la préférence entre des candidats générés de type Chaîne de Pensée (CoT) en optimisant les étapes de raisonnement gagnantes par rapport aux perdantes qui mènent à la bonne réponse. Nous entraînons en utilisant une fonction de perte DPO modifiée (Rafailov et al., 2023) avec un terme supplémentaire de log-vraisemblance négative, que nous jugeons crucial. Nous montrons que le raisonnement s'améliore au fil des itérations répétées de ce schéma. Bien que ne s'appuyant que sur des exemples de l'ensemble d'entraînement, notre approche permet d'augmenter la précision de Llama-2-70B-Chat de 55,6 % à 81,6 % sur GSM8K (et 88,7 % avec un vote majoritaire sur 32 échantillons), de 12,5 % à 20,8 % sur MATH, et de 77,8 % à 86,7 % sur ARC-Challenge, surpassant ainsi d'autres modèles basés sur Llama-2 qui ne reposent pas sur des ensembles de données supplémentaires.
English
Iterative preference optimization methods have recently been shown to perform well for general instruction tuning tasks, but typically make little improvement on reasoning tasks (Yuan et al., 2024, Chen et al., 2024). In this work we develop an iterative approach that optimizes the preference between competing generated Chain-of-Thought (CoT) candidates by optimizing for winning vs. losing reasoning steps that lead to the correct answer. We train using a modified DPO loss (Rafailov et al., 2023) with an additional negative log-likelihood term, which we find to be crucial. We show reasoning improves across repeated iterations of this scheme. While only relying on examples in the training set, our approach results in increasing accuracy for Llama-2-70B-Chat from 55.6% to 81.6% on GSM8K (and 88.7% with majority voting out of 32 samples), from 12.5% to 20.8% on MATH, and from 77.8% to 86.7% on ARC-Challenge, which outperforms other Llama-2-based models not relying on additionally sourced datasets.

Summary

AI-Generated Summary

PDF506December 8, 2024