dParallel: dLLM向け学習可能な並列デコード手法
dParallel: Learnable Parallel Decoding for dLLMs
September 30, 2025
著者: Zigeng Chen, Gongfan Fang, Xinyin Ma, Ruonan Yu, Xinchao Wang
cs.AI
要旨
拡散型大規模言語モデル(dLLM)は、最近、自己回帰型生成の有望な代替手段として研究コミュニティで注目を集めており、並列トークン予測と低い推論遅延を提供します。しかし、その並列デコードの可能性はまだ十分に探求されておらず、既存のオープンソースモデルでは性能を確保するためにほぼトークン長のデコードステップが必要です。この問題に対処するため、我々はdParallelを導入します。これは、dLLMの内在する並列性を活用して高速サンプリングを実現するシンプルで効果的な手法です。並列デコードの主要なボトルネックが、マスクされたトークンの逐次的確実性収束にあることを特定しました。この洞察に基づき、我々のアプローチの中核となる「確実性強制蒸留」を導入します。これは、モデルが元のサンプリング軌跡を追従しつつ、マスクされたトークンに対してより迅速かつ並列に高い確実性を達成するよう訓練する新しい戦略です。様々なベンチマークでの広範な実験により、我々の手法が性能を維持しながらデコードステップ数を劇的に削減できることが示されました。LLaDA-8B-Instructモデルに適用した場合、dParallelはGSM8Kでのデコードステップを256から30に削減し、性能低下なしに8.5倍の高速化を実現しました。MBPPベンチマークでは、デコードステップを256から24に削減し、精度を維持しながら10.5倍の高速化を達成しました。我々のコードはhttps://github.com/czg1225/dParallelで公開されています。
English
Diffusion large language models (dLLMs) have recently drawn considerable
attention within the research community as a promising alternative to
autoregressive generation, offering parallel token prediction and lower
inference latency. Yet, their parallel decoding potential remains largely
underexplored, as existing open-source models still require nearly token-length
decoding steps to ensure performance. To address this, we introduce dParallel,
a simple and effective method that unlocks the inherent parallelism of dLLMs
for fast sampling. We identify that the key bottleneck to parallel decoding
arises from the sequential certainty convergence for masked tokens. Building on
this insight, we introduce the core of our approach: certainty-forcing
distillation, a novel training strategy that distills the model to follow its
original sampling trajectories while enforcing it to achieve high certainty on
masked tokens more rapidly and in parallel. Extensive experiments across
various benchmarks demonstrate that our method can dramatically reduce the
number of decoding steps while maintaining performance. When applied to the
LLaDA-8B-Instruct model, dParallel reduces decoding steps from 256 to 30 on
GSM8K, achieving an 8.5x speedup without performance degradation. On the MBPP
benchmark, it cuts decoding steps from 256 to 24, resulting in a 10.5x speedup
while maintaining accuracy. Our code is available at
https://github.com/czg1225/dParallel