スパースロジットサンプリング:大規模言語モデルにおける知識蒸留の高速化
Sparse Logit Sampling: Accelerating Knowledge Distillation in LLMs
March 21, 2025
著者: Anshumann, Mohd Abbas Zaidi, Akhil Kedia, Jinwoo Ahn, Taehwak Kwon, Kangwook Lee, Haejun Lee, Joohyung Lee
cs.AI
要旨
知識蒸留は、教師モデルの出力ロジットを事前計算してキャッシュできる場合、大規模言語モデルにおける知識の抽出において費用対効果の高い技術となり得ます。しかし、これを事前学習に適用することは、まだほとんど検討されていません。本研究では、Top-K確率をキャッシュするといった素朴なスパース知識蒸留のアプローチは直感的ではあるものの、教師の確率分布を学生モデルに偏った形で推定し、結果として最適でない性能とキャリブレーションをもたらすことを証明します。我々は、重要度サンプリングに基づく手法「ランダムサンプリング知識蒸留」を提案します。この手法は不偏推定を提供し、期待値において勾配を保存し、さらに大幅にスパースなロジットの保存を可能にします。我々の手法は、300Mから3Bまでの様々なモデルサイズにおいて、完全な蒸留と比較して競争力のある性能を維持しつつ、クロスエントロピーに基づく学習と比較してわずかなオーバーヘッド(10%未満)で学生モデルのより高速な学習を実現します。
English
Knowledge distillation can be a cost-effective technique to distill knowledge
in Large Language Models, if the teacher output logits can be pre-computed and
cached. However, successfully applying this to pre-training remains largely
unexplored. In this work, we prove that naive approaches for sparse knowledge
distillation such as caching Top-K probabilities, while intuitive, provide
biased estimates of teacher probability distribution to the student, resulting
in suboptimal performance and calibration. We propose an
importance-sampling-based method `Random Sampling Knowledge Distillation',
which provides unbiased estimates, preserves the gradient in expectation, and
requires storing significantly sparser logits. Our method enables faster
training of student models with marginal overhead (<10%) compared to
cross-entropy based training, while maintaining competitive performance
compared to full distillation, across a range of model sizes from 300M to 3B.Summary
AI-Generated Summary