ChatPaper.aiChatPaper

잠재 변수 추론을 통한 사고 사슬(Chain-of-Thought) 학습

Training Chain-of-Thought via Latent-Variable Inference

November 28, 2023
저자: Du Phan, Matthew D. Hoffman, David Dohan, Sholto Douglas, Tuan Anh Le, Aaron Parisi, Pavel Sountsov, Charles Sutton, Sharad Vikram, Rif A. Saurous
cs.AI

초록

대형 언어 모델(LLM)은 "사고의 연쇄"(Chain-of-Thought, CoT) 프롬프트를 사용해 단계별로 답을 도출하도록 지시받을 때 더 정확하고 해석 가능하게 문제를 해결합니다. 또한, 특정 작업에서 LLM의 성능을 개선하기 위해 지도 학습 기반 미세 조정(supervised fine-tuning)을 사용할 수 있습니다. 이는 조정 가능한 매개변수에 대해 경사 상승법(gradient ascent)을 적용하여 레이블이 지정된 훈련 세트에서 정답의 평균 로그 가능도를 최대화하는 방식으로 이루어집니다. CoT와 지도 학습을 단순히 결합하려면 정답뿐만 아니라 그 정답에 이르는 상세한 논리(rationale)에 대한 지도도 필요합니다. 그러나 이러한 논리를 수작업으로 생성하는 것은 비용이 많이 듭니다. 대신, 우리는 CoT 프롬프트를 사용하여 정답을 생성하는 주변 로그 가능도(marginal log-likelihood)를 최대화하려는 미세 조정 전략을 제안합니다. 이는 가능한 모든 논리에 대해 근사적으로 평균을 내는 방식입니다. 핵심 과제는 정답에 조건부로 주어진 논리에 대한 사후 분포(posterior)에서 샘플링하는 것입니다. 우리는 이를 해결하기 위해 자기 학습 추론기(self-taught reasoner, STaR), 메모이제이션된 웨이크-슬립(memoized wake-sleep), 마르코프 점수 상승(Markovian score climbing), 지속적 대조 발산(persistent contrastive divergence)에서 영감을 받은 간단한 마르코프 체인 몬테 카를로(MCMC) 기대값 최대화(EM) 알고리즘을 사용합니다. 이 알고리즘은 또한 모델이 개선됨에 따라 그래디언트 추정치의 분산을 0으로 줄이는 새로운 제어 변수(control-variate) 기법을 허용합니다. 우리의 기법을 GSM8K 및 BIG-Bench Hard의 작업에 적용한 결과, 이 MCMC-EM 미세 조정 기법이 일반적으로 CoT를 사용하거나 사용하지 않은 STaR 또는 프롬프트 튜닝보다 보유된 예제에서 모델의 정확도를 더 크게 향상시키는 것을 확인했습니다.
English
Large language models (LLMs) solve problems more accurately and interpretably when instructed to work out the answer step by step using a ``chain-of-thought'' (CoT) prompt. One can also improve LLMs' performance on a specific task by supervised fine-tuning, i.e., by using gradient ascent on some tunable parameters to maximize the average log-likelihood of correct answers from a labeled training set. Naively combining CoT with supervised tuning requires supervision not just of the correct answers, but also of detailed rationales that lead to those answers; these rationales are expensive to produce by hand. Instead, we propose a fine-tuning strategy that tries to maximize the marginal log-likelihood of generating a correct answer using CoT prompting, approximately averaging over all possible rationales. The core challenge is sampling from the posterior over rationales conditioned on the correct answer; we address it using a simple Markov-chain Monte Carlo (MCMC) expectation-maximization (EM) algorithm inspired by the self-taught reasoner (STaR), memoized wake-sleep, Markovian score climbing, and persistent contrastive divergence. This algorithm also admits a novel control-variate technique that drives the variance of our gradient estimates to zero as the model improves. Applying our technique to GSM8K and the tasks in BIG-Bench Hard, we find that this MCMC-EM fine-tuning technique typically improves the model's accuracy on held-out examples more than STaR or prompt-tuning with or without CoT.
PDF110December 15, 2024