Building Math Agents with Multi-Turn Iterative Preference Learning
September 4, 2024
Authors: Wei Xiong, Chengshuai Shi, Jiaming Shen, Aviv Rosenberg, Zhen Qin, Daniele Calandriello, Misha Khalman, Rishabh Joshi, Bilal Piot, Mohammad Saleh, Chi Jin, Tong Zhang, Tianqi Liu
cs.AI
Abstract
Recent studies have shown that the mathematical problem-solving capabilities
of large language models (LLMs) can be enhanced by integrating external
tools, such as code interpreters, and employing multi-turn Chain-of-Thought
(CoT) reasoning. While current methods focus on synthetic data generation and
Supervised Fine-Tuning (SFT), this paper studies the complementary direct
preference learning approach to further improve model performance. However,
existing direct preference learning algorithms were originally designed for
single-turn chat tasks and do not fully address the complexities of
multi-turn reasoning and external tool integration required for
tool-integrated mathematical reasoning. To fill this gap, we introduce a
multi-turn direct preference learning framework tailored for this context,
which leverages feedback from code interpreters and optimizes
trajectory-level preferences. The framework includes multi-turn DPO and
multi-turn KTO as specific implementations. We validate its effectiveness by
training various language models on an augmented prompt set from the GSM8K
and MATH datasets. Our results demonstrate substantial improvements: a
supervised fine-tuned Gemma-1.1-it-7B model's accuracy on GSM8K increased
from 77.5% to 83.9%, and on MATH from 46.1% to 51.2%. Similarly, a
Gemma-2-it-9B model improved from 84.1% to 86.3% on GSM8K and from 51.0% to
54.5% on MATH.
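
To make the "trajectory-level preferences" idea concrete, here is a minimal
sketch (our own notation, not an equation taken from the abstract above). In
the tool-integrated setting, a trajectory tau = (a_1, o_1, ..., a_H)
interleaves model-generated actions a_h with code-interpreter observations
o_h. A trajectory-level DPO objective of the kind the paper describes can be
written by scoring only the model-generated tokens, conditioning on (but not
scoring) the interpreter observations:

```latex
% Trajectory log-likelihood over model-generated actions only;
% interpreter observations o_h are conditioned on but not scored.
\log \pi_\theta(\tau \mid x)
  = \sum_{h=1}^{H} \log \pi_\theta\!\left(a_h \mid x, a_1, o_1, \ldots, o_{h-1}\right)

% Multi-turn DPO loss over a preferred/dispreferred trajectory pair
% (\tau^w, \tau^l), with reference policy \pi_{\mathrm{ref}} and temperature \beta:
\mathcal{L}(\theta) = -\,\mathbb{E}\!\left[
  \log \sigma\!\left( \beta \left[
    \log \frac{\pi_\theta(\tau^w \mid x)}{\pi_{\mathrm{ref}}(\tau^w \mid x)}
    - \log \frac{\pi_\theta(\tau^l \mid x)}{\pi_{\mathrm{ref}}(\tau^l \mid x)}
  \right] \right) \right]
```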
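A minimal PyTorch sketch of the same idea follows. Everything here (function
names, tensor shapes, the 0/1 action-mask convention) is our own illustration
of trajectory-level preference optimization with observation masking, not the
authors' released implementation:

```python
import torch
import torch.nn.functional as F

def trajectory_logprob(token_logprobs: torch.Tensor,
                       action_mask: torch.Tensor) -> torch.Tensor:
    """Sum per-token log-probs over model-generated (action) tokens only.

    token_logprobs: (batch, seq_len) log-probs of the realized tokens.
    action_mask:    (batch, seq_len) 1.0 for model-generated tokens,
                    0.0 for code-interpreter observations and padding.
    """
    return (token_logprobs * action_mask).sum(dim=-1)

def multi_turn_dpo_loss(policy_w: torch.Tensor, policy_l: torch.Tensor,
                        ref_w: torch.Tensor, ref_l: torch.Tensor,
                        beta: float = 0.1) -> torch.Tensor:
    """Pairwise DPO loss over whole trajectories.

    Each argument is a (batch,) tensor of masked trajectory log-probs
    from trajectory_logprob (policy vs. frozen reference model;
    w = preferred trajectory, l = dispreferred trajectory).
    """
    margin = beta * ((policy_w - ref_w) - (policy_l - ref_l))
    return -F.logsigmoid(margin).mean()
```

The design choice the masking reflects: interpreter outputs are produced by
the environment, not the policy, so optimizing their likelihood would push
probability mass onto tokens the model cannot control.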