Training van Ketting-van-Gedachten via Latente-Variabele Inferentie
Training Chain-of-Thought via Latent-Variable Inference
November 28, 2023
Auteurs: Du Phan, Matthew D. Hoffman, David Dohan, Sholto Douglas, Tuan Anh Le, Aaron Parisi, Pavel Sountsov, Charles Sutton, Sharad Vikram, Rif A. Saurous
cs.AI
Samenvatting
Grote taalmodellen (LLMs) lossen problemen nauwkeuriger en interpreteerbaarder op wanneer ze worden geïnstrueerd om het antwoord stap voor stap uit te werken met behulp van een "chain-of-thought" (CoT) prompt. Men kan ook de prestaties van LLMs op een specifieke taak verbeteren door supervised fine-tuning, d.w.z. door gebruik te maken van gradient ascent op enkele afstelbare parameters om de gemiddelde log-waarschijnlijkheid van correcte antwoorden uit een gelabelde trainingsset te maximaliseren. Het naïef combineren van CoT met supervised tuning vereist niet alleen supervisie van de correcte antwoorden, maar ook van gedetailleerde redeneringen die tot die antwoorden leiden; deze redeneringen zijn kostbaar om handmatig te produceren. In plaats daarvan stellen we een fine-tuning strategie voor die probeert de marginale log-waarschijnlijkheid van het genereren van een correct antwoord met behulp van CoT prompting te maximaliseren, waarbij ongeveer gemiddeld wordt over alle mogelijke redeneringen. De kernuitdaging is het bemonsteren van de posterior over redeneringen geconditioneerd op het correcte antwoord; we pakken dit aan met een eenvoudig Markov-chain Monte Carlo (MCMC) expectation-maximization (EM) algoritme geïnspireerd door de self-taught reasoner (STaR), memoized wake-sleep, Markovian score climbing, en persistent contrastive divergence. Dit algoritme maakt ook gebruik van een nieuwe controle-variabele techniek die de variantie van onze gradient schattingen naar nul drijft naarmate het model verbetert. Door onze techniek toe te passen op GSM8K en de taken in BIG-Bench Hard, ontdekken we dat deze MCMC-EM fine-tuning techniek doorgaans de nauwkeurigheid van het model op achtergehouden voorbeelden meer verbetert dan STaR of prompt-tuning met of zonder CoT.
English
Large language models (LLMs) solve problems more accurately and interpretably
when instructed to work out the answer step by step using a
``chain-of-thought'' (CoT) prompt. One can also improve LLMs' performance on a
specific task by supervised fine-tuning, i.e., by using gradient ascent on some
tunable parameters to maximize the average log-likelihood of correct answers
from a labeled training set. Naively combining CoT with supervised tuning
requires supervision not just of the correct answers, but also of detailed
rationales that lead to those answers; these rationales are expensive to
produce by hand. Instead, we propose a fine-tuning strategy that tries to
maximize the marginal log-likelihood of generating a correct answer
using CoT prompting, approximately averaging over all possible rationales. The
core challenge is sampling from the posterior over rationales conditioned on
the correct answer; we address it using a simple Markov-chain Monte Carlo
(MCMC) expectation-maximization (EM) algorithm inspired by the self-taught
reasoner (STaR), memoized wake-sleep, Markovian score climbing, and persistent
contrastive divergence. This algorithm also admits a novel control-variate
technique that drives the variance of our gradient estimates to zero as the
model improves. Applying our technique to GSM8K and the tasks in BIG-Bench
Hard, we find that this MCMC-EM fine-tuning technique typically improves the
model's accuracy on held-out examples more than STaR or prompt-tuning with or
without CoT.