Mise à l'échelle efficace de la longueur de pré-entraînement
Efficient Pretraining Length Scaling
April 21, 2025
Auteurs: Bohong Wu, Shen Yan, Sijun Zhang, Jianqiao Lu, Yutao Zeng, Ya Wang, Xun Zhou
cs.AI
Résumé
Les récents progrès des grands modèles de langage ont démontré l'efficacité de la mise à l'échelle de la longueur lors du post-entraînement, mais son potentiel pendant le pré-entraînement reste sous-exploré. Nous présentons le Parallel Hidden Decoding Transformer (PHD-Transformer), un cadre novateur qui permet une mise à l'échelle de la longueur efficace pendant le pré-entraînement tout en maintenant l'efficacité de l'inférence. Le PHD-Transformer y parvient grâce à une stratégie innovante de gestion du cache KV qui distingue les tokens originaux des tokens de décodage cachés. En conservant uniquement le cache KV des tokens originaux pour les dépendances à longue portée tout en éliminant immédiatement les tokens de décodage cachés après leur utilisation, notre approche maintient la même taille de cache KV que le transformer classique tout en permettant une mise à l'échelle de la longueur efficace. Pour améliorer encore les performances, nous introduisons deux variantes optimisées : PHD-SWA utilise une attention par fenêtre glissante pour préserver les dépendances locales, tandis que PHD-CSWA met en œuvre une attention par fenêtre glissante par morceaux pour éliminer la croissance linéaire du temps de pré-remplissage. Des expériences approfondies démontrent des améliorations constantes sur plusieurs benchmarks.
English
Recent advances in large language models have demonstrated the effectiveness
of length scaling during post-training, yet its potential in pre-training
remains underexplored. We present the Parallel Hidden Decoding Transformer
(PHD-Transformer), a novel framework that enables efficient length
scaling during pre-training while maintaining inference efficiency.
PHD-Transformer achieves this through an innovative KV cache
management strategy that distinguishes between original tokens and hidden
decoding tokens. By retaining only the KV cache of original tokens for
long-range dependencies while immediately discarding hidden decoding tokens
after use, our approach maintains the same KV cache size as the vanilla
transformer while enabling effective length scaling. To further enhance
performance, we introduce two optimized variants: PHD-SWA employs
sliding window attention to preserve local dependencies, while
PHD-CSWA implements chunk-wise sliding window attention to eliminate
linear growth in pre-filling time. Extensive experiments demonstrate consistent
improvements across multiple benchmarks.Summary
AI-Generated Summary