Treinamento de Cadeia de Pensamento via Inferência de Variáveis Latentes
Training Chain-of-Thought via Latent-Variable Inference
November 28, 2023
Autores: Du Phan, Matthew D. Hoffman, David Dohan, Sholto Douglas, Tuan Anh Le, Aaron Parisi, Pavel Sountsov, Charles Sutton, Sharad Vikram, Rif A. Saurous
cs.AI
Resumo
Modelos de linguagem de grande escala (LLMs) resolvem problemas de forma mais precisa e interpretável quando instruídos a trabalhar na resposta passo a passo usando um prompt de "cadeia de pensamento" (chain-of-thought, CoT). Também é possível melhorar o desempenho dos LLMs em uma tarefa específica por meio de ajuste fino supervisionado, ou seja, usando gradiente ascendente em alguns parâmetros ajustáveis para maximizar a média da log-verossimilhança das respostas corretas de um conjunto de treinamento rotulado. Combinar CoT com ajuste fino supervisionado de forma ingênua exige supervisão não apenas das respostas corretas, mas também das justificativas detalhadas que levam a essas respostas; essas justificativas são caras para serem produzidas manualmente. Em vez disso, propomos uma estratégia de ajuste fino que tenta maximizar a log-verossimilhança marginal de gerar uma resposta correta usando o prompt CoT, aproximadamente calculando a média sobre todas as justificativas possíveis. O desafio central é amostrar da distribuição posterior sobre as justificativas condicionadas à resposta correta; abordamos isso usando um algoritmo simples de maximização de expectativa (EM) com cadeia de Markov Monte Carlo (MCMC), inspirado no raciocínio autodidata (STaR), no método memoizado wake-sleep, na subida de pontuação markoviana e na divergência contrastiva persistente. Esse algoritmo também admite uma técnica inovadora de variável de controle que reduz a variância de nossas estimativas de gradiente a zero à medida que o modelo melhora. Aplicando nossa técnica ao GSM8K e às tarefas do BIG-Bench Hard, descobrimos que essa técnica de ajuste fino MCMC-EM geralmente melhora a precisão do modelo em exemplos de teste mais do que o STaR ou o ajuste de prompt com ou sem CoT.
English
Large language models (LLMs) solve problems more accurately and interpretably
when instructed to work out the answer step by step using a
``chain-of-thought'' (CoT) prompt. One can also improve LLMs' performance on a
specific task by supervised fine-tuning, i.e., by using gradient ascent on some
tunable parameters to maximize the average log-likelihood of correct answers
from a labeled training set. Naively combining CoT with supervised tuning
requires supervision not just of the correct answers, but also of detailed
rationales that lead to those answers; these rationales are expensive to
produce by hand. Instead, we propose a fine-tuning strategy that tries to
maximize the marginal log-likelihood of generating a correct answer
using CoT prompting, approximately averaging over all possible rationales. The
core challenge is sampling from the posterior over rationales conditioned on
the correct answer; we address it using a simple Markov-chain Monte Carlo
(MCMC) expectation-maximization (EM) algorithm inspired by the self-taught
reasoner (STaR), memoized wake-sleep, Markovian score climbing, and persistent
contrastive divergence. This algorithm also admits a novel control-variate
technique that drives the variance of our gradient estimates to zero as the
model improves. Applying our technique to GSM8K and the tasks in BIG-Bench
Hard, we find that this MCMC-EM fine-tuning technique typically improves the
model's accuracy on held-out examples more than STaR or prompt-tuning with or
without CoT.