Prédiction multi-jetons efficace sans entraînement par sondage de l'espace d'embedding
Efficient Training-Free Multi-Token Prediction via Embedding-Space Probing
March 18, 2026
Auteurs: Raghavv Goel, Mukul Gagrani, Mingu Lee, Chris Lott
cs.AI
Résumé
Les grands modèles de langage (LLM) présentent des capacités latentes de prédiction multi-jetons (MTP) bien qu'ils soient entraînés uniquement pour la génération de jetons suivants. Nous proposons une approche MTP simple et sans entraînement qui sonde un LLM en utilisant des jetons de masque générés à la volée à partir de son espace d'embedding, permettant la prédiction parallèle de jetons futurs sans modifier les poids du modèle ni recourir à des modèles d'ébauche auxiliaires. Notre méthode construit un arbre spéculatif de jetons en échantillonnant les meilleurs candidats K à partir des logits des jetons de masque et applique une stratégie légère d'élagage pour conserver les suites à forte probabilité. Pendant le décodage, les prédictions candidates sont vérifiées en parallèle, ce qui génère une production sans perte tout en réduisant considérablement le nombre d'appels au modèle et en améliorant le débit de jetons. Sur divers benchmarks, notre MTP par sondage surpasse systématiquement les méthodes de référence sans entraînement existantes, augmentant la longueur d'acceptation d'environ 12\% sur LLaMA3 et de 8 à 12\% sur Qwen3, et obtenant des gains de débit allant jusqu'à 15-19\%. Enfin, nous fournissons des insights théoriques et des preuves empiriques montrant que les couches décodeurs alignent naturellement les représentations des jetons de masque avec les états des jetons suivants, permettant une prédiction multi-étapes précise sans réentraînement ni modèles auxiliaires.
English
Large language models (LLMs) exhibit latent multi-token prediction (MTP) capabilities despite being trained solely for next-token generation. We propose a simple, training-free MTP approach that probes an LLM using on-the-fly mask tokens drawn from its embedding space, enabling parallel prediction of future tokens without modifying model weights or relying on auxiliary draft models. Our method constructs a speculative token tree by sampling top-K candidates from mask-token logits and applies a lightweight pruning strategy to retain high-probability continuations. During decoding, candidate predictions are verified in parallel, resulting in lossless generation while substantially reducing the number of model calls and improving token throughput. Across benchmarks, our probing-based MTP consistently outperforms existing training-free baselines, increasing acceptance length by approximately 12\% on LLaMA3 and 8--12\% on Qwen3, and achieving throughput gains of up to 15--19\%. Finally, we provide theoretical insights and empirical evidence showing that decoder layers naturally align mask-token representations with next-token states, enabling accurate multi-step prediction without retraining or auxiliary models.