Optimisation des raisonneurs en chaîne de pensée par minimisation de la variance du gradient dans l'échantillonnage par rejet et l'apprentissage par renforcement
Optimizing Chain-of-Thought Reasoners via Gradient Variance Minimization in Rejection Sampling and RL
May 5, 2025
Auteurs: Jiarui Yao, Yifan Hao, Hanning Zhang, Hanze Dong, Wei Xiong, Nan Jiang, Tong Zhang
cs.AI
Résumé
Le raisonnement en chaîne de pensée (Chain-of-Thought, CoT) dans les grands modèles de langage (LLMs) peut être formalisé comme un problème de variable latente, où le modèle doit générer des étapes de raisonnement intermédiaires. Bien que des approches antérieures telles que le fine-tuning itératif par récompense classée (RAFT) se soient appuyées sur de telles formulations, elles appliquent généralement des budgets d'inférence uniformes pour tous les prompts, ce qui ne tient pas compte de la variabilité de la difficulté et du comportement de convergence. Ce travail identifie le principal goulot d'étranglement dans l'entraînement CoT comme étant une estimation inefficace du gradient stochastique due à des stratégies d'échantillonnage statiques. Nous proposons GVM-RAFT, une stratégie dynamique d'allocation d'échantillons spécifique au prompt, conçue pour minimiser la variance du gradient stochastique sous une contrainte de budget computationnel. La méthode alloue dynamiquement les ressources computationnelles en surveillant les taux d'acceptation des prompts et les normes du gradient stochastique, garantissant ainsi que la variance résultante du gradient est minimisée. Notre analyse théorique montre que la stratégie d'échantillonnage dynamique proposée conduit à des garanties de convergence accélérées sous des conditions appropriées. Les expériences sur le raisonnement mathématique montrent que GVM-RAFT atteint une accélération de 2 à 4 fois et des améliorations considérables en précision par rapport à RAFT standard. La stratégie d'échantillonnage dynamique proposée est générale et peut être intégrée dans d'autres algorithmes d'apprentissage par renforcement, tels que GRPO, conduisant à des améliorations similaires en convergence et en précision de test. Notre code est disponible à l'adresse https://github.com/RLHFlow/GVM.
English
Chain-of-thought (CoT) reasoning in large language models (LLMs) can be
formalized as a latent variable problem, where the model needs to generate
intermediate reasoning steps. While prior approaches such as iterative
reward-ranked fine-tuning (RAFT) have relied on such formulations, they
typically apply uniform inference budgets across prompts, which fails to
account for variability in difficulty and convergence behavior. This work
identifies the main bottleneck in CoT training as inefficient stochastic
gradient estimation due to static sampling strategies. We propose GVM-RAFT, a
prompt-specific Dynamic Sample Allocation Strategy designed to minimize
stochastic gradient variance under a computational budget constraint. The
method dynamically allocates computational resources by monitoring prompt
acceptance rates and stochastic gradient norms, ensuring that the resulting
gradient variance is minimized. Our theoretical analysis shows that the
proposed dynamic sampling strategy leads to accelerated convergence guarantees
under suitable conditions. Experiments on mathematical reasoning show that
GVM-RAFT achieves a 2-4x speedup and considerable accuracy improvements over
vanilla RAFT. The proposed dynamic sampling strategy is general and can be
incorporated into other reinforcement learning algorithms, such as GRPO,
leading to similar improvements in convergence and test accuracy. Our code is
available at https://github.com/RLHFlow/GVM.Summary
AI-Generated Summary