거부 샘플링과 강화 학습에서 그래디언트 분산 최소화를 통한 사고 연쇄 추론기 최적화
Optimizing Chain-of-Thought Reasoners via Gradient Variance Minimization in Rejection Sampling and RL
May 5, 2025
저자: Jiarui Yao, Yifan Hao, Hanning Zhang, Hanze Dong, Wei Xiong, Nan Jiang, Tong Zhang
cs.AI
초록
대규모 언어 모델(LLMs)에서의 사고 연쇄(Chain-of-Thought, CoT) 추론은 모델이 중간 추론 단계를 생성해야 하는 잠재 변수 문제로 공식화될 수 있습니다. 이전의 반복적 보험 순위 미세 조정(RAFT)과 같은 접근 방식은 이러한 공식화에 의존해 왔지만, 일반적으로 프롬프트 전반에 균일한 추론 예산을 적용하여 난이도와 수렴 행동의 변동성을 고려하지 못했습니다. 본 연구는 CoT 훈련의 주요 병목 현상을 정적 샘플링 전략으로 인한 비효율적인 확률적 경사 추정으로 식별합니다. 우리는 계산 예산 제약 하에서 확률적 경사 분산을 최소화하도록 설계된 프롬프트 특화 동적 샘플 할당 전략인 GVM-RAFT를 제안합니다. 이 방법은 프롬프트 수용률과 확률적 경사 노름을 모니터링하여 계산 자원을 동적으로 할당함으로써 결과적인 경사 분산이 최소화되도록 보장합니다. 우리의 이론적 분석은 제안된 동적 샘플링 전략이 적절한 조건 하에서 가속화된 수렴 보장을 이끌어냄을 보여줍니다. 수학적 추론 실험에서 GVM-RAFT는 기본 RAFT 대비 2-4배의 속도 향상과 상당한 정확도 개선을 달성했습니다. 제안된 동적 샘플링 전략은 일반적이며 GRPO와 같은 다른 강화 학습 알고리즘에 통합될 수 있어 수렴 및 테스트 정확도에서 유사한 개선을 이끌어냅니다. 우리의 코드는 https://github.com/RLHFlow/GVM에서 확인할 수 있습니다.
English
Chain-of-thought (CoT) reasoning in large language models (LLMs) can be
formalized as a latent variable problem, where the model needs to generate
intermediate reasoning steps. While prior approaches such as iterative
reward-ranked fine-tuning (RAFT) have relied on such formulations, they
typically apply uniform inference budgets across prompts, which fails to
account for variability in difficulty and convergence behavior. This work
identifies the main bottleneck in CoT training as inefficient stochastic
gradient estimation due to static sampling strategies. We propose GVM-RAFT, a
prompt-specific Dynamic Sample Allocation Strategy designed to minimize
stochastic gradient variance under a computational budget constraint. The
method dynamically allocates computational resources by monitoring prompt
acceptance rates and stochastic gradient norms, ensuring that the resulting
gradient variance is minimized. Our theoretical analysis shows that the
proposed dynamic sampling strategy leads to accelerated convergence guarantees
under suitable conditions. Experiments on mathematical reasoning show that
GVM-RAFT achieves a 2-4x speedup and considerable accuracy improvements over
vanilla RAFT. The proposed dynamic sampling strategy is general and can be
incorporated into other reinforcement learning algorithms, such as GRPO,
leading to similar improvements in convergence and test accuracy. Our code is
available at https://github.com/RLHFlow/GVM.Summary
AI-Generated Summary