拒絶サンプリングと強化学習における勾配分散最小化による 連鎖的思考推論器の最適化
Optimizing Chain-of-Thought Reasoners via Gradient Variance Minimization in Rejection Sampling and RL
May 5, 2025
著者: Jiarui Yao, Yifan Hao, Hanning Zhang, Hanze Dong, Wei Xiong, Nan Jiang, Tong Zhang
cs.AI
要旨
大規模言語モデル(LLM)における連鎖的思考(Chain-of-Thought, CoT)推論は、モデルが中間的な推論ステップを生成する必要がある潜在変数問題として形式化することができます。これまでのアプローチ、例えば反復的報酬ランク付きファインチューニング(RAFT)などは、このような定式化に依存してきましたが、通常はプロンプト全体に均一な推論予算を適用しており、難易度や収束行動の変動を考慮していませんでした。本研究では、CoTトレーニングの主要なボトルネックとして、静的サンプリング戦略による確率的勾配推定の非効率性を特定しました。我々は、計算予算制約下で確率的勾配分散を最小化するために設計された、プロンプト固有の動的サンプル割り当て戦略であるGVM-RAFTを提案します。この手法は、プロンプトの受容率と確率的勾配ノルムを監視することで計算リソースを動的に割り当て、結果として得られる勾配分散が最小化されることを保証します。理論的分析により、提案された動的サンプリング戦略が適切な条件下で加速された収束保証をもたらすことが示されています。数学的推論に関する実験では、GVM-RAFTがバニラRAFTと比較して2~4倍の高速化と精度の大幅な向上を達成しました。提案された動的サンプリング戦略は汎用的であり、GRPOなどの他の強化学習アルゴリズムに組み込むことができ、同様の収束とテスト精度の向上をもたらします。コードはhttps://github.com/RLHFlow/GVMで公開されています。
English
Chain-of-thought (CoT) reasoning in large language models (LLMs) can be
formalized as a latent variable problem, where the model needs to generate
intermediate reasoning steps. While prior approaches such as iterative
reward-ranked fine-tuning (RAFT) have relied on such formulations, they
typically apply uniform inference budgets across prompts, which fails to
account for variability in difficulty and convergence behavior. This work
identifies the main bottleneck in CoT training as inefficient stochastic
gradient estimation due to static sampling strategies. We propose GVM-RAFT, a
prompt-specific Dynamic Sample Allocation Strategy designed to minimize
stochastic gradient variance under a computational budget constraint. The
method dynamically allocates computational resources by monitoring prompt
acceptance rates and stochastic gradient norms, ensuring that the resulting
gradient variance is minimized. Our theoretical analysis shows that the
proposed dynamic sampling strategy leads to accelerated convergence guarantees
under suitable conditions. Experiments on mathematical reasoning show that
GVM-RAFT achieves a 2-4x speedup and considerable accuracy improvements over
vanilla RAFT. The proposed dynamic sampling strategy is general and can be
incorporated into other reinforcement learning algorithms, such as GRPO,
leading to similar improvements in convergence and test accuracy. Our code is
available at https://github.com/RLHFlow/GVM.Summary
AI-Generated Summary