大規模言語モデルを用いた数学的推論学習におけるスケーリング関係
Scaling Relationship on Learning Mathematical Reasoning with Large Language Models
August 3, 2023
著者: Zheng Yuan, Hongyi Yuan, Chengpeng Li, Guanting Dong, Chuanqi Tan, Chang Zhou
cs.AI
要旨
大規模言語モデル(LLM)にとって、数学的推論は困難なタスクであり、そのスケーリング特性とLLMの能力との関係は十分に探究されていない。本論文では、事前学習損失、教師ありデータ量、および拡張データ量が、教師ありLLMの推論性能にどのように影響するかを調査する。事前学習損失は、モデルのパラメータ数よりも性能の優れた指標であることを見出した。異なる量の教師ありデータを用いて教師ありファインチューニング(SFT)を適用し、データ量とモデル性能の間に対数線形関係が存在することを実証的に確認し、より優れたモデルは拡張された教師ありデータセットでの改善が少ないことを発見した。人間の労力をかけずにモデル性能を向上させるためにより多くのデータサンプルを拡張するために、Rejection Sampling Fine-Tuning(RFT)を提案する。RFTは、教師ありモデルを使用して正しい推論パスを生成し、拡張ファインチューニングデータセットとして収集する。より多様な推論パスを含む拡張サンプルを用いることで、RFTはLLMの数学的推論性能をさらに向上させることがわかった。また、RFTは性能の低いLLMに対してより大きな改善をもたらすことも発見した。さらに、複数のモデルからのリジェクトサンプルを組み合わせることで、LLaMA-7Bの精度を49.3%に押し上げ、教師ありファインチューニング(SFT)の精度35.9%を大幅に上回る結果を得た。
English
Mathematical reasoning is a challenging task for large language models
(LLMs), while the scaling relationship of it with respect to LLM capacity is
under-explored. In this paper, we investigate how the pre-training loss,
supervised data amount, and augmented data amount influence the reasoning
performances of a supervised LLM. We find that pre-training loss is a better
indicator of the model's performance than the model's parameter count. We apply
supervised fine-tuning (SFT) with different amounts of supervised data and
empirically find a log-linear relation between data amount and model
performance, and we find better models improve less with enlarged supervised
datasets. To augment more data samples for improving model performances without
any human effort, we propose to apply Rejection sampling Fine-Tuning (RFT). RFT
uses supervised models to generate and collect correct reasoning paths as
augmented fine-tuning datasets. We find with augmented samples containing more
distinct reasoning paths, RFT improves mathematical reasoning performance more
for LLMs. We also find RFT brings more improvement for less performant LLMs.
Furthermore, we combine rejection samples from multiple models which push
LLaMA-7B to an accuracy of 49.3% and outperforms the supervised fine-tuning
(SFT) accuracy of 35.9% significantly.