ChatPaper.aiChatPaper

通过梯度方差最小化优化链式思维推理器: 拒绝采样与强化学习中的应用

Optimizing Chain-of-Thought Reasoners via Gradient Variance Minimization in Rejection Sampling and RL

May 5, 2025
作者: Jiarui Yao, Yifan Hao, Hanning Zhang, Hanze Dong, Wei Xiong, Nan Jiang, Tong Zhang
cs.AI

摘要

大语言模型(LLMs)中的思维链(CoT)推理可形式化为一个潜在变量问题,其中模型需生成中间推理步骤。尽管先前的方法如迭代奖励排序微调(RAFT)依赖此类形式化,但它们通常对提示采用统一的推理预算,未能考虑到难度与收敛行为的差异性。本工作指出CoT训练中的主要瓶颈在于静态采样策略导致的随机梯度估计效率低下。我们提出了GVM-RAFT,一种针对提示的动态样本分配策略,旨在计算预算约束下最小化随机梯度方差。该方法通过监控提示接受率与随机梯度范数,动态分配计算资源,确保所得梯度方差最小化。理论分析表明,在适当条件下,所提出的动态采样策略能加速收敛保证。数学推理实验显示,GVM-RAFT相比原始RAFT实现了2-4倍的加速及显著的准确率提升。该动态采样策略具有通用性,可融入其他强化学习算法,如GRPO,带来类似的收敛速度与测试准确率提升。我们的代码公开于https://github.com/RLHFlow/GVM。
English
Chain-of-thought (CoT) reasoning in large language models (LLMs) can be formalized as a latent variable problem, where the model needs to generate intermediate reasoning steps. While prior approaches such as iterative reward-ranked fine-tuning (RAFT) have relied on such formulations, they typically apply uniform inference budgets across prompts, which fails to account for variability in difficulty and convergence behavior. This work identifies the main bottleneck in CoT training as inefficient stochastic gradient estimation due to static sampling strategies. We propose GVM-RAFT, a prompt-specific Dynamic Sample Allocation Strategy designed to minimize stochastic gradient variance under a computational budget constraint. The method dynamically allocates computational resources by monitoring prompt acceptance rates and stochastic gradient norms, ensuring that the resulting gradient variance is minimized. Our theoretical analysis shows that the proposed dynamic sampling strategy leads to accelerated convergence guarantees under suitable conditions. Experiments on mathematical reasoning show that GVM-RAFT achieves a 2-4x speedup and considerable accuracy improvements over vanilla RAFT. The proposed dynamic sampling strategy is general and can be incorporated into other reinforcement learning algorithms, such as GRPO, leading to similar improvements in convergence and test accuracy. Our code is available at https://github.com/RLHFlow/GVM.

Summary

AI-Generated Summary

PDF181May 6, 2025