

Building Math Agents with Multi-Turn Iterative Preference Learning

September 4, 2024
作者: Wei Xiong, Chengshuai Shi, Jiaming Shen, Aviv Rosenberg, Zhen Qin, Daniele Calandriello, Misha Khalman, Rishabh Joshi, Bilal Piot, Mohammad Saleh, Chi Jin, Tong Zhang, Tianqi Liu
cs.AI

Abstract

Recent studies have shown that large language models' (LLMs) mathematical problem-solving capabilities can be enhanced by integrating external tools, such as code interpreters, and employing multi-turn Chain-of-Thought (CoT) reasoning. While current methods focus on synthetic data generation and Supervised Fine-Tuning (SFT), this paper studies the complementary direct preference learning approach to further improve model performance. However, existing direct preference learning algorithms were originally designed for single-turn chat tasks and do not fully address the complexities of multi-turn reasoning and external tool integration required for tool-integrated mathematical reasoning. To fill this gap, we introduce a multi-turn direct preference learning framework, tailored for this context, that leverages feedback from code interpreters and optimizes trajectory-level preferences. This framework includes multi-turn DPO and multi-turn KTO as specific implementations. The effectiveness of our framework is validated by training various language models on an augmented prompt set drawn from the GSM8K and MATH datasets. Our results demonstrate substantial improvements: a supervised fine-tuned Gemma-1.1-it-7B model's performance increased from 77.5% to 83.9% on GSM8K and from 46.1% to 51.2% on MATH. Similarly, a Gemma-2-it-9B model improved from 84.1% to 86.3% on GSM8K and from 51.0% to 54.5% on MATH.
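To make the trajectory-level objective concrete, the following is a minimal sketch of a multi-turn DPO-style loss. It assumes per-turn log-probabilities have already been computed and that turns produced by the code interpreter are excluded (masked) from the sums, as the abstract's framework requires; the function name and interface are illustrative, not the paper's actual implementation.

```python
import math

def multi_turn_dpo_loss(chosen_logps, rejected_logps,
                        ref_chosen_logps, ref_rejected_logps,
                        beta=0.1):
    """Trajectory-level DPO loss (illustrative sketch).

    Each argument is a list of per-turn log-probabilities for the
    model-generated turns of one trajectory; code-interpreter outputs
    are assumed to be masked out before this point. `ref_*` lists come
    from the frozen reference policy, and `beta` is the usual DPO
    temperature.
    """
    # Trajectory log-probability = sum over model-generated turns.
    pi_w, pi_l = sum(chosen_logps), sum(rejected_logps)
    ref_w, ref_l = sum(ref_chosen_logps), sum(ref_rejected_logps)

    # Standard DPO logistic loss, applied at the trajectory level:
    # -log sigmoid(beta * [(log pi_w - log ref_w) - (log pi_l - log ref_l)])
    margin = beta * ((pi_w - ref_w) - (pi_l - ref_l))
    return -math.log(1.0 / (1.0 + math.exp(-margin)))
```

With a zero margin (policy identical to the reference), the loss is log 2; as the policy assigns relatively more mass to the preferred trajectory, the loss decreases toward zero.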