潜在変数推論による連鎖思考のトレーニング
Training Chain-of-Thought via Latent-Variable Inference
November 28, 2023
著者: Du Phan, Matthew D. Hoffman, David Dohan, Sholto Douglas, Tuan Anh Le, Aaron Parisi, Pavel Sountsov, Charles Sutton, Sharad Vikram, Rif A. Saurous
cs.AI
要旨
大規模言語モデル(LLM)は、「連鎖的思考」(Chain-of-Thought, CoT)プロンプトを使用して段階的に答えを導くよう指示されると、問題をより正確かつ解釈可能に解決します。また、特定のタスクにおけるLLMの性能を向上させるために、教師ありファインチューニングを行うことができます。これは、ラベル付きトレーニングセットから正解の対数尤度の平均を最大化するために、調整可能なパラメータに対して勾配上昇法を使用するものです。CoTと教師ありチューニングを単純に組み合わせる場合、正解だけでなく、その答えに至る詳細な論理(rationale)の教師データも必要となりますが、これらの論理を手作業で作成するのはコストがかかります。代わりに、我々は、CoTプロンプトを使用して正解を生成する際の周辺対数尤度を最大化するファインチューニング戦略を提案します。これは、すべての可能な論理を近似的に平均化するものです。核心的な課題は、正解を条件とした論理の事後分布からのサンプリングです。これを解決するために、自己学習推論器(STaR)、メモ化されたウェイクスリープ、マルコフスコア上昇法、および持続的コントラスティブダイバージェンスに着想を得た、シンプルなマルコフ連鎖モンテカルロ(MCMC)期待値最大化(EM)アルゴリズムを使用します。このアルゴリズムは、モデルが改善されるにつれて勾配推定の分散をゼロに近づける新しい制御変数技術も導入します。GSM8KおよびBIG-Bench Hardのタスクにこの技術を適用した結果、このMCMC-EMファインチューニング技術は、CoTの有無にかかわらず、STaRやプロンプトチューニングよりも、検証データに対するモデルの精度を向上させることが一般的に確認されました。
English
Large language models (LLMs) solve problems more accurately and interpretably
when instructed to work out the answer step by step using a
``chain-of-thought'' (CoT) prompt. One can also improve LLMs' performance on a
specific task by supervised fine-tuning, i.e., by using gradient ascent on some
tunable parameters to maximize the average log-likelihood of correct answers
from a labeled training set. Naively combining CoT with supervised tuning
requires supervision not just of the correct answers, but also of detailed
rationales that lead to those answers; these rationales are expensive to
produce by hand. Instead, we propose a fine-tuning strategy that tries to
maximize the marginal log-likelihood of generating a correct answer
using CoT prompting, approximately averaging over all possible rationales. The
core challenge is sampling from the posterior over rationales conditioned on
the correct answer; we address it using a simple Markov-chain Monte Carlo
(MCMC) expectation-maximization (EM) algorithm inspired by the self-taught
reasoner (STaR), memoized wake-sleep, Markovian score climbing, and persistent
contrastive divergence. This algorithm also admits a novel control-variate
technique that drives the variance of our gradient estimates to zero as the
model improves. Applying our technique to GSM8K and the tasks in BIG-Bench
Hard, we find that this MCMC-EM fine-tuning technique typically improves the
model's accuracy on held-out examples more than STaR or prompt-tuning with or
without CoT.