Sparse Logit Sampling: Accelerating Knowledge Distillation in LLMs
March 21, 2025
Authors: Anshumann, Mohd Abbas Zaidi, Akhil Kedia, Jinwoo Ahn, Taehwak Kwon, Kangwook Lee, Haejun Lee, Joohyung Lee
cs.AI
Abstract
Knowledge distillation can be a cost-effective technique to distill knowledge in Large Language Models, if the teacher's output logits can be pre-computed and cached. However, successfully applying this to pre-training remains largely unexplored. In this work, we prove that naive approaches for sparse knowledge distillation, such as caching Top-K probabilities, while intuitive, provide biased estimates of the teacher's probability distribution to the student, resulting in suboptimal performance and calibration. We propose an importance-sampling-based method, `Random Sampling Knowledge Distillation`, which provides unbiased estimates, preserves the gradient in expectation, and requires storing significantly sparser logits. Our method enables faster training of student models with marginal overhead (<10%) compared to cross-entropy-based training, while maintaining performance competitive with full distillation across a range of model sizes from 300M to 3B.
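
To make the importance-sampling idea concrete, below is a minimal PyTorch sketch of how sparse, unbiased distillation targets could be cached and consumed. It samples K vocabulary indices per token from a proposal distribution derived from the teacher and reweights by teacher-probability / proposal-probability, so the estimated cross-entropy (and hence its gradient with respect to the student) is unbiased in expectation while only K values per token are stored. The function names, the tempered proposal, and K are illustrative assumptions, not the paper's exact recipe.

```python
# Minimal sketch (assumed names and hyperparameters), not the authors' exact algorithm.
import torch
import torch.nn.functional as F


@torch.no_grad()
def sample_teacher_support(teacher_logits: torch.Tensor, k: int,
                           proposal_temp: float = 1.0):
    """Cache a sparse view of the teacher: K sampled token ids, their teacher
    probabilities, and the proposal probabilities used for importance weights."""
    p = torch.softmax(teacher_logits, dim=-1)                    # [tokens, vocab]
    proposal = torch.softmax(teacher_logits / proposal_temp, dim=-1)
    idx = torch.multinomial(proposal, k, replacement=True)       # [tokens, K]
    p_idx = torch.gather(p, -1, idx)                             # teacher prob of sampled ids
    prop_idx = torch.gather(proposal, -1, idx)                   # proposal prob of sampled ids
    return idx, p_idx, prop_idx                                  # only these are stored


def sparse_kd_loss(student_logits: torch.Tensor, idx: torch.Tensor,
                   p_idx: torch.Tensor, prop_idx: torch.Tensor) -> torch.Tensor:
    """Monte-Carlo estimate of the distillation cross-entropy
    -sum_v p(v) log q(v) = E_{v ~ proposal}[ (p(v) / proposal(v)) * (-log q(v)) ],
    which is unbiased in expectation because the importance weights are constants
    with respect to the student parameters."""
    log_q = F.log_softmax(student_logits, dim=-1)                # [tokens, vocab]
    log_q_idx = torch.gather(log_q, -1, idx)                     # [tokens, K]
    weights = p_idx / prop_idx                                   # importance weights
    return -(weights * log_q_idx).mean()
```

In this sketch, the teacher pass and `sample_teacher_support` run once offline; training then reads only the cached `(idx, p_idx, prop_idx)` triples (K values per token rather than the full vocabulary), which is what makes the cached targets sparse.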