
Sparse Logit Sampling: Accelerating Knowledge Distillation in LLMs

March 21, 2025
Authors: Anshumann, Mohd Abbas Zaidi, Akhil Kedia, Jinwoo Ahn, Taehwak Kwon, Kangwook Lee, Haejun Lee, Joohyung Lee
cs.AI

Abstract

Knowledge distillation can be a cost-effective technique for distilling knowledge from Large Language Models, provided the teacher's output logits can be pre-computed and cached. However, successfully applying this to pre-training remains largely unexplored. In this work, we prove that naive approaches to sparse knowledge distillation, such as caching Top-K probabilities, while intuitive, provide biased estimates of the teacher's probability distribution to the student, resulting in suboptimal performance and calibration. We propose an importance-sampling-based method, "Random Sampling Knowledge Distillation", which provides unbiased estimates, preserves the gradient in expectation, and requires storing significantly sparser logits. Our method enables faster training of student models with marginal overhead (<10%) compared to cross-entropy-based training, while maintaining competitive performance compared to full distillation, across a range of model sizes from 300M to 3B.
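The core contrast in the abstract, biased Top-K caching versus an unbiased sampled estimate of the teacher distribution, can be illustrated with a small sketch. The snippet below is only a plausible reading of the idea, not the authors' code: it uses PyTorch, invents the helper names `topk_estimate` and `random_sampling_estimate`, and takes the simplest importance-sampling proposal (the teacher distribution itself, so the weights reduce to a uniform 1/K average); the paper's actual storage format and weighting scheme may differ.

```python
# Minimal sketch (not the authors' implementation): estimating the distillation
# objective sum_v p_teacher(v) * log q_student(v) from only K cached entries.
import torch


def topk_estimate(p_teacher: torch.Tensor, log_q_student: torch.Tensor, k: int) -> torch.Tensor:
    """Biased: keep the K largest teacher probabilities and renormalize them."""
    probs, idx = torch.topk(p_teacher, k)
    probs = probs / probs.sum()          # renormalization shifts probability mass
    return (probs * log_q_student[idx]).sum()


def random_sampling_estimate(p_teacher: torch.Tensor, log_q_student: torch.Tensor, k: int) -> torch.Tensor:
    """Unbiased: draw K token ids from the teacher distribution (with replacement)
    and average log q over them -- a Monte Carlo estimate of E_{v~p}[log q(v)]."""
    idx = torch.multinomial(p_teacher, k, replacement=True)
    return log_q_student[idx].mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    vocab = 32_000
    p = torch.softmax(torch.randn(vocab), dim=-1)           # stand-in teacher probabilities
    log_q = torch.log_softmax(torch.randn(vocab), dim=-1)   # stand-in student log-probabilities
    exact = (p * log_q).sum()
    print(f"exact: {exact.item():.4f}  "
          f"top-k: {topk_estimate(p, log_q, 64).item():.4f}  "
          f"random: {random_sampling_estimate(p, log_q, 64).item():.4f}")
```

Because the sampled estimator is linear in log q, its gradient with respect to the student parameters is also unbiased in expectation, which is what allows a sparse cache of sampled teacher entries to stand in for the full distribution during pre-training.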
