Entrenamiento del Pensamiento en Cadena mediante Inferencia de Variables Latentes
Training Chain-of-Thought via Latent-Variable Inference
November 28, 2023
Autores: Du Phan, Matthew D. Hoffman, David Dohan, Sholto Douglas, Tuan Anh Le, Aaron Parisi, Pavel Sountsov, Charles Sutton, Sharad Vikram, Rif A. Saurous
cs.AI
Resumen
Los modelos de lenguaje de gran escala (LLMs, por sus siglas en inglés) resuelven problemas con mayor precisión y capacidad de interpretación cuando se les instruye para trabajar en la respuesta paso a paso utilizando un prompt de "cadena de pensamiento" (CoT, por sus siglas en inglés). También se puede mejorar el rendimiento de los LLMs en una tarea específica mediante ajuste fino supervisado, es decir, utilizando el ascenso de gradiente sobre algunos parámetros ajustables para maximizar la log-verosimilitud promedio de las respuestas correctas en un conjunto de entrenamiento etiquetado. Combinar de manera ingenua CoT con el ajuste supervisado requiere supervisión no solo de las respuestas correctas, sino también de las razones detalladas que llevan a esas respuestas; estas razones son costosas de producir manualmente. En su lugar, proponemos una estrategia de ajuste fino que intenta maximizar la log-verosimilitud marginal de generar una respuesta correcta utilizando el prompt CoT, promediando aproximadamente sobre todas las razones posibles. El desafío principal es muestrear a partir de la distribución posterior sobre las razones condicionadas a la respuesta correcta; lo abordamos utilizando un algoritmo simple de maximización de expectativas (EM) basado en cadenas de Markov Monte Carlo (MCMC), inspirado en el razonador autodidacta (STaR), el método de sueño-memorización (memoized wake-sleep), la escalada de puntuación markoviana y la divergencia contrastiva persistente. Este algoritmo también admite una técnica novedosa de control de variación que reduce la varianza de nuestras estimaciones de gradiente a cero a medida que el modelo mejora. Al aplicar nuestra técnica a GSM8K y a las tareas de BIG-Bench Hard, encontramos que este método de ajuste fino MCMC-EM generalmente mejora la precisión del modelo en ejemplos de prueba más que STaR o el ajuste de prompts con o sin CoT.
English
Large language models (LLMs) solve problems more accurately and interpretably
when instructed to work out the answer step by step using a
``chain-of-thought'' (CoT) prompt. One can also improve LLMs' performance on a
specific task by supervised fine-tuning, i.e., by using gradient ascent on some
tunable parameters to maximize the average log-likelihood of correct answers
from a labeled training set. Naively combining CoT with supervised tuning
requires supervision not just of the correct answers, but also of detailed
rationales that lead to those answers; these rationales are expensive to
produce by hand. Instead, we propose a fine-tuning strategy that tries to
maximize the marginal log-likelihood of generating a correct answer
using CoT prompting, approximately averaging over all possible rationales. The
core challenge is sampling from the posterior over rationales conditioned on
the correct answer; we address it using a simple Markov-chain Monte Carlo
(MCMC) expectation-maximization (EM) algorithm inspired by the self-taught
reasoner (STaR), memoized wake-sleep, Markovian score climbing, and persistent
contrastive divergence. This algorithm also admits a novel control-variate
technique that drives the variance of our gradient estimates to zero as the
model improves. Applying our technique to GSM8K and the tasks in BIG-Bench
Hard, we find that this MCMC-EM fine-tuning technique typically improves the
model's accuracy on held-out examples more than STaR or prompt-tuning with or
without CoT.