ChatPaper.aiChatPaper

GLoRe: 大規模言語モデルの推論をグローバルおよびローカルな精緻化によって改善するタイミング、場所、方法

GLoRe: When, Where, and How to Improve LLM Reasoning via Global and Local Refinements

February 13, 2024
著者: Alex Havrilla, Sharath Raparthy, Christoforus Nalmpantis, Jane Dwivedi-Yu, Maksym Zhuravinskyi, Eric Hambro, Roberta Railneau
cs.AI

要旨

最先端の言語モデルは、数学、科学、またはコーディングタスクにおいて、印象的な推論改善能力を示すことがあります。しかし、最近の研究では、外部フィードバックにアクセスできない場合、最良のモデルでさえ、いつ、どこで改善すべきかを特定するのに苦労することが示されています。最終的な答えの正しさを予測し、いつ改善すべきかを示すOutcome-based Reward Models(ORMs)は、改善のタイミングを決定するための便利な解決策を提供します。一方、中間ステップの正しさを予測するProcess Based Reward Models(PRMs)は、どこで改善すべきかを示すために使用できますが、これらは広範な人間のアノテーションを必要とするため、訓練にコストがかかります。本論文では、合成データのみを用いて訓練され、最適ポリシーまたはV^{star}の将来の報酬を近似するStepwise ORMs(SORMs)を提案します。具体的には、SORMsは、現在のポリシーを複数回サンプリングした場合の最終的な答えの正しさを予測するように訓練されます(ORMsの場合のように一度だけではなく)。実験結果は、SORMsがORMsと比較して、誤った推論ステップをより正確に検出できることを示しており、改善を行う際の下流の精度を向上させます。次に、質問と草案の解決策を入力として受け取り、修正された解決策を予測するグローバル改善モデルと、最初の推論エラーの位置を示す批評も入力として受け取るローカル改善モデルを訓練します。両モデルの訓練データは、SORMの訓練に使用されたデータを再利用して合成的に生成します。グローバルとローカルの改善を組み合わせ、ORMをリランカーとして使用することで、個別の改善や、3つのサンプルのベストを上回る性能を発揮することがわかりました。この戦略により、RLで既にファインチューニングされたLLaMA-2 13BモデルのGSM8Kにおける精度を、貪欲サンプリング時に53%から65%に向上させることができます。
English
State-of-the-art language models can exhibit impressive reasoning refinement capabilities on math, science or coding tasks. However, recent work demonstrates that even the best models struggle to identify when and where to refine without access to external feedback. Outcome-based Reward Models (ORMs), trained to predict correctness of the final answer indicating when to refine, offer one convenient solution for deciding when to refine. Process Based Reward Models (PRMs), trained to predict correctness of intermediate steps, can then be used to indicate where to refine. But they are expensive to train, requiring extensive human annotations. In this paper, we propose Stepwise ORMs (SORMs) which are trained, only on synthetic data, to approximate the expected future reward of the optimal policy or V^{star}. More specifically, SORMs are trained to predict the correctness of the final answer when sampling the current policy many times (rather than only once as in the case of ORMs). Our experiments show that SORMs can more accurately detect incorrect reasoning steps compared to ORMs, thus improving downstream accuracy when doing refinements. We then train global refinement models, which take only the question and a draft solution as input and predict a corrected solution, and local refinement models which also take as input a critique indicating the location of the first reasoning error. We generate training data for both models synthetically by reusing data used to train the SORM. We find combining global and local refinements, using the ORM as a reranker, significantly outperforms either one individually, as well as a best of three sample baseline. With this strategy we can improve the accuracy of a LLaMA-2 13B model (already fine-tuned with RL) on GSM8K from 53\% to 65\% when greedily sampled.
PDF121December 15, 2024