Scalabilità Efficiente della Lunghezza del Pretraining
Efficient Pretraining Length Scaling
April 21, 2025
Autori: Bohong Wu, Shen Yan, Sijun Zhang, Jianqiao Lu, Yutao Zeng, Ya Wang, Xun Zhou
cs.AI
Abstract
I recenti progressi nei modelli linguistici di grandi dimensioni hanno dimostrato l'efficacia del ridimensionamento della lunghezza durante il post-training, tuttavia il suo potenziale nel pre-training rimane ancora poco esplorato. Presentiamo il Parallel Hidden Decoding Transformer (PHD-Transformer), un nuovo framework che consente un efficiente ridimensionamento della lunghezza durante il pre-training mantenendo al contempo l'efficienza nell'inferenza. Il PHD-Transformer raggiunge questo obiettivo attraverso una strategia innovativa di gestione della cache KV che distingue tra token originali e token di decodifica nascosti. Conservando solo la cache KV dei token originali per le dipendenze a lungo raggio e scartando immediatamente i token di decodifica nascosti dopo l'uso, il nostro approccio mantiene la stessa dimensione della cache KV del transformer tradizionale, consentendo un efficace ridimensionamento della lunghezza. Per migliorare ulteriormente le prestazioni, introduciamo due varianti ottimizzate: PHD-SWA utilizza l'attenzione a finestra scorrevole per preservare le dipendenze locali, mentre PHD-CSWA implementa l'attenzione a finestra scorrevole a blocchi per eliminare la crescita lineare nel tempo di pre-riempimento. Esperimenti estesi dimostrano miglioramenti consistenti su molteplici benchmark.
English
Recent advances in large language models have demonstrated the effectiveness
of length scaling during post-training, yet its potential in pre-training
remains underexplored. We present the Parallel Hidden Decoding Transformer
(PHD-Transformer), a novel framework that enables efficient length
scaling during pre-training while maintaining inference efficiency.
PHD-Transformer achieves this through an innovative KV cache
management strategy that distinguishes between original tokens and hidden
decoding tokens. By retaining only the KV cache of original tokens for
long-range dependencies while immediately discarding hidden decoding tokens
after use, our approach maintains the same KV cache size as the vanilla
transformer while enabling effective length scaling. To further enhance
performance, we introduce two optimized variants: PHD-SWA employs
sliding window attention to preserve local dependencies, while
PHD-CSWA implements chunk-wise sliding window attention to eliminate
linear growth in pre-filling time. Extensive experiments demonstrate consistent
improvements across multiple benchmarks.