ChatPaper.aiChatPaper

다중 턴 반복적 선호 학습을 통한 수학 에이전트 구축

Building Math Agents with Multi-Turn Iterative Preference Learning

September 4, 2024
저자: Wei Xiong, Chengshuai Shi, Jiaming Shen, Aviv Rosenberg, Zhen Qin, Daniele Calandriello, Misha Khalman, Rishabh Joshi, Bilal Piot, Mohammad Saleh, Chi Jin, Tong Zhang, Tianqi Liu
cs.AI

초록

최근 연구에 따르면 코드 해석기와 같은 외부 도구를 통합하고 다중 턴 사고 연쇄(CoT) 추론을 적용하면 대규모 언어 모델(LLM)의 수학 문제 해결 능력을 향상시킬 수 있다. 현재의 방법론이 합성 데이터 생성과 지도 미세 조정(SFT)에 초점을 맞추고 있는 가운데, 본 논문은 모델 성능을 더욱 개선하기 위한 보완적 접근법으로 직접 선호도 학습을 연구한다. 그러나 기존의 직접 선호도 학습 알고리즘은 단일 턴 채팅 작업을 위해 설계되어, 도구 통합 수학 추론 과제에 필요한 다중 턴 추론과 외부 도구 통합의 복잡성을 완전히 해결하지 못한다. 이러한 격차를 메우기 위해, 우리는 코드 해석기의 피드백을 활용하고 궤적 수준 선호도를 최적화하는, 이 맥락에 맞게 설계된 다중 턴 직접 선호도 학습 프레임워크를 소개한다. 이 프레임워크는 다중 턴 DPO와 다중 턴 KTO를 구체적인 구현 방식으로 포함한다. 우리 프레임워크의 효과는 GSM8K 및 MATH 데이터셋의 확장된 프롬프트 세트를 사용하여 다양한 언어 모델을 학습시킴으로써 검증되었다. 그 결과, 지도 미세 조정된 Gemma-1.1-it-7B 모델의 성능은 GSM8K에서 77.5%에서 83.9%로, MATH에서 46.1%에서 51.2%로 크게 향상되었다. 유사하게, Gemma-2-it-9B 모델은 GSM8K에서 84.1%에서 86.3%로, MATH에서 51.0%에서 54.5%로 개선되었다.
English
Recent studies have shown that large language models' (LLMs) mathematical problem-solving capabilities can be enhanced by integrating external tools, such as code interpreters, and employing multi-turn Chain-of-Thought (CoT) reasoning. While current methods focus on synthetic data generation and Supervised Fine-Tuning (SFT), this paper studies the complementary direct preference learning approach to further improve model performance. However, existing direct preference learning algorithms are originally designed for the single-turn chat task, and do not fully address the complexities of multi-turn reasoning and external tool integration required for tool-integrated mathematical reasoning tasks. To fill in this gap, we introduce a multi-turn direct preference learning framework, tailored for this context, that leverages feedback from code interpreters and optimizes trajectory-level preferences. This framework includes multi-turn DPO and multi-turn KTO as specific implementations. The effectiveness of our framework is validated through training of various language models using an augmented prompt set from the GSM8K and MATH datasets. Our results demonstrate substantial improvements: a supervised fine-tuned Gemma-1.1-it-7B model's performance increased from 77.5% to 83.9% on GSM8K and from 46.1% to 51.2% on MATH. Similarly, a Gemma-2-it-9B model improved from 84.1% to 86.3% on GSM8K and from 51.0% to 54.5% on MATH.
PDF162November 14, 2024