ChatPaper.aiChatPaper

Optimierung von Chain-of-Thought-Reasonern durch Minimierung der Gradientenvarianz in Rejection Sampling und Reinforcement Learning

Optimizing Chain-of-Thought Reasoners via Gradient Variance Minimization in Rejection Sampling and RL

May 5, 2025
Autoren: Jiarui Yao, Yifan Hao, Hanning Zhang, Hanze Dong, Wei Xiong, Nan Jiang, Tong Zhang
cs.AI

Zusammenfassung

Chain-of-Thought (CoT)-Reasoning in großen Sprachmodellen (LLMs) kann als ein latentes Variablenproblem formalisiert werden, bei dem das Modell Zwischenschritte der Argumentation generieren muss. Während frühere Ansätze wie das iterative Reward-Ranked Fine-Tuning (RAFT) auf solchen Formulierungen basierten, wenden sie typischerweise einheitliche Inferenzbudgets über alle Prompts hinweg an, was die Variabilität in Schwierigkeit und Konvergenzverhalten nicht berücksichtigt. Diese Arbeit identifiziert den Hauptengpass im CoT-Training als ineffiziente stochastische Gradientenschätzung aufgrund statischer Sampling-Strategien. Wir schlagen GVM-RAFT vor, eine prompt-spezifische dynamische Sample-Allokationsstrategie, die darauf abzielt, die Varianz des stochastischen Gradienten unter einer Rechenbudgetbeschränkung zu minimieren. Die Methode weist Rechenressourcen dynamisch zu, indem sie die Akzeptanzraten der Prompts und die Normen der stochastischen Gradienten überwacht, um sicherzustellen, dass die resultierende Gradientenvarianz minimiert wird. Unsere theoretische Analyse zeigt, dass die vorgeschlagene dynamische Sampling-Strategie unter geeigneten Bedingungen zu beschleunigten Konvergenzgarantien führt. Experimente zur mathematischen Argumentation zeigen, dass GVM-RAFT eine 2-4-fache Beschleunigung und erhebliche Genauigkeitsverbesserungen gegenüber dem Standard-RAFT erreicht. Die vorgeschlagene dynamische Sampling-Strategie ist allgemein und kann in andere Reinforcement-Learning-Algorithmen wie GRPO integriert werden, was zu ähnlichen Verbesserungen in Konvergenz und Testgenauigkeit führt. Unser Code ist verfügbar unter https://github.com/RLHFlow/GVM.
English
Chain-of-thought (CoT) reasoning in large language models (LLMs) can be formalized as a latent variable problem, where the model needs to generate intermediate reasoning steps. While prior approaches such as iterative reward-ranked fine-tuning (RAFT) have relied on such formulations, they typically apply uniform inference budgets across prompts, which fails to account for variability in difficulty and convergence behavior. This work identifies the main bottleneck in CoT training as inefficient stochastic gradient estimation due to static sampling strategies. We propose GVM-RAFT, a prompt-specific Dynamic Sample Allocation Strategy designed to minimize stochastic gradient variance under a computational budget constraint. The method dynamically allocates computational resources by monitoring prompt acceptance rates and stochastic gradient norms, ensuring that the resulting gradient variance is minimized. Our theoretical analysis shows that the proposed dynamic sampling strategy leads to accelerated convergence guarantees under suitable conditions. Experiments on mathematical reasoning show that GVM-RAFT achieves a 2-4x speedup and considerable accuracy improvements over vanilla RAFT. The proposed dynamic sampling strategy is general and can be incorporated into other reinforcement learning algorithms, such as GRPO, leading to similar improvements in convergence and test accuracy. Our code is available at https://github.com/RLHFlow/GVM.

Summary

AI-Generated Summary

PDF181May 6, 2025