BOND: Alinhando LLMs com Destilação Best-of-N
BOND: Aligning LLMs with Best-of-N Distillation
July 19, 2024
Autores: Pier Giuseppe Sessa, Robert Dadashi, Léonard Hussenot, Johan Ferret, Nino Vieillard, Alexandre Ramé, Bobak Shariari, Sarah Perrin, Abe Friesen, Geoffrey Cideron, Sertan Girgin, Piotr Stanczyk, Andrea Michi, Danila Sinopalnikov, Sabela Ramos, Amélie Héliou, Aliaksei Severyn, Matt Hoffman, Nikola Momchev, Olivier Bachem
cs.AI
Resumo
A aprendizagem por reforço a partir do feedback humano (RLHF) é um fator-chave de qualidade e segurança em modelos de linguagem grandes de última geração. No entanto, uma estratégia surpreendentemente simples e forte durante a inferência é a seleção do Melhor-de-N amostragem que escolhe a melhor geração entre N candidatos. Neste artigo, propomos a Destilação Melhor-de-N (BOND), um novo algoritmo RLHF que busca emular o Melhor-de-N, mas sem o significativo custo computacional durante a inferência. Especificamente, BOND é um algoritmo de correspondência de distribuição que força a distribuição das gerações da política a se aproximar da distribuição Melhor-de-N. Utilizamos a divergência de Jeffreys (uma combinação linear de KL direta e reversa) para equilibrar entre cobertura de modo e comportamento de busca de modo, e derivamos uma formulação iterativa que utiliza uma âncora móvel para eficiência. Demonstramos a eficácia de nossa abordagem e várias escolhas de projeto por meio de experimentos em sumarização abstrativa e modelos Gemma. Alinhar as políticas Gemma com BOND supera outros algoritmos RLHF ao melhorar os resultados em vários benchmarks.
English
Reinforcement learning from human feedback (RLHF) is a key driver of quality
and safety in state-of-the-art large language models. Yet, a surprisingly
simple and strong inference-time strategy is Best-of-N sampling that selects
the best generation among N candidates. In this paper, we propose Best-of-N
Distillation (BOND), a novel RLHF algorithm that seeks to emulate Best-of-N but
without its significant computational overhead at inference time. Specifically,
BOND is a distribution matching algorithm that forces the distribution of
generations from the policy to get closer to the Best-of-N distribution. We use
the Jeffreys divergence (a linear combination of forward and backward KL) to
balance between mode-covering and mode-seeking behavior, and derive an
iterative formulation that utilizes a moving anchor for efficiency. We
demonstrate the effectiveness of our approach and several design choices
through experiments on abstractive summarization and Gemma models. Aligning
Gemma policies with BOND outperforms other RLHF algorithms by improving results
on several benchmarks.