Addestramento della Catena di Pensiero tramite Inferenza a Variabili Latenti
Training Chain-of-Thought via Latent-Variable Inference
November 28, 2023
Autori: Du Phan, Matthew D. Hoffman, David Dohan, Sholto Douglas, Tuan Anh Le, Aaron Parisi, Pavel Sountsov, Charles Sutton, Sharad Vikram, Rif A. Saurous
cs.AI
Abstract
I modelli linguistici di grandi dimensioni (LLM) risolvono i problemi in modo più accurato e interpretabile quando vengono istruiti a elaborare la risposta passo dopo passo utilizzando un prompt a "catena di pensiero" (CoT). È anche possibile migliorare le prestazioni degli LLM su un compito specifico attraverso la messa a punto supervisionata, ovvero utilizzando l'ascesa del gradiente su alcuni parametri regolabili per massimizzare la log-verosimiglianza media delle risposte corrette da un insieme di addestramento etichettato. Combinare in modo ingenuo il CoT con la messa a punto supervisionata richiede non solo la supervisione delle risposte corrette, ma anche delle ragioni dettagliate che portano a tali risposte; queste ragioni sono costose da produrre manualmente. Proponiamo invece una strategia di messa a punto che cerca di massimizzare la log-verosimiglianza marginale di generare una risposta corretta utilizzando il prompting CoT, approssimando la media su tutte le possibili ragioni. La sfida principale è campionare dalla distribuzione a posteriori sulle ragioni condizionata alla risposta corretta; affrontiamo questo problema utilizzando un semplice algoritmo di massimizzazione delle aspettazioni (EM) basato su catene di Markov Monte Carlo (MCMC), ispirato dal ragionatore auto-apprendente (STaR), dal metodo memoized wake-sleep, dalla scalata del punteggio markoviano e dalla divergenza contrastiva persistente. Questo algoritmo ammette anche una nuova tecnica di controllo delle variabili che riduce la varianza delle nostre stime del gradiente a zero man mano che il modello migliora. Applicando la nostra tecnica a GSM8K e ai compiti di BIG-Bench Hard, scopriamo che questa tecnica di messa a punto MCMC-EM migliora tipicamente l'accuratezza del modello sugli esempi di test più di STaR o del prompt-tuning con o senza CoT.
English
Large language models (LLMs) solve problems more accurately and interpretably
when instructed to work out the answer step by step using a
``chain-of-thought'' (CoT) prompt. One can also improve LLMs' performance on a
specific task by supervised fine-tuning, i.e., by using gradient ascent on some
tunable parameters to maximize the average log-likelihood of correct answers
from a labeled training set. Naively combining CoT with supervised tuning
requires supervision not just of the correct answers, but also of detailed
rationales that lead to those answers; these rationales are expensive to
produce by hand. Instead, we propose a fine-tuning strategy that tries to
maximize the marginal log-likelihood of generating a correct answer
using CoT prompting, approximately averaging over all possible rationales. The
core challenge is sampling from the posterior over rationales conditioned on
the correct answer; we address it using a simple Markov-chain Monte Carlo
(MCMC) expectation-maximization (EM) algorithm inspired by the self-taught
reasoner (STaR), memoized wake-sleep, Markovian score climbing, and persistent
contrastive divergence. This algorithm also admits a novel control-variate
technique that drives the variance of our gradient estimates to zero as the
model improves. Applying our technique to GSM8K and the tasks in BIG-Bench
Hard, we find that this MCMC-EM fine-tuning technique typically improves the
model's accuracy on held-out examples more than STaR or prompt-tuning with or
without CoT.