ChatPaper.aiChatPaper

dParallel: dLLM을 위한 학습 가능한 병렬 디코딩

dParallel: Learnable Parallel Decoding for dLLMs

September 30, 2025
저자: Zigeng Chen, Gongfan Fang, Xinyin Ma, Ruonan Yu, Xinchao Wang
cs.AI

초록

확산 기반 대형 언어 모델(dLLMs)은 최근 연구 커뮤니티에서 상당한 주목을 받고 있으며, 이는 병렬 토큰 예측과 더 낮은 추론 지연 시간을 제공함으로써 자기회귀 생성 방식의 유망한 대안으로 여겨지고 있습니다. 그러나 이러한 병렬 디코딩 잠재력은 아직 크게 탐구되지 않았는데, 기존의 오픈소스 모델들은 여전히 성능을 보장하기 위해 거의 토큰 길이에 가까운 디코딩 단계를 필요로 합니다. 이를 해결하기 위해, 우리는 dParallel이라는 간단하면서도 효과적인 방법을 소개합니다. 이 방법은 dLLMs의 내재된 병렬성을 활용하여 빠른 샘플링을 가능하게 합니다. 우리는 병렬 디코딩의 주요 병목 현상이 마스킹된 토큰에 대한 순차적인 확실성 수렴에서 비롯된다는 것을 발견했습니다. 이러한 통찰을 바탕으로, 우리는 핵심 접근 방식인 확실성 강제 증류(certainty-forcing distillation)를 도입했습니다. 이는 모델이 원래의 샘플링 궤적을 따르도록 하면서도 마스킹된 토큰에 대해 더 빠르고 병렬적으로 높은 확실성을 달성하도록 강제하는 새로운 훈련 전략입니다. 다양한 벤치마크에서의 광범위한 실험을 통해, 우리의 방법이 성능을 유지하면서도 디코딩 단계 수를 극적으로 줄일 수 있음을 입증했습니다. LLaDA-8B-Instruct 모델에 적용했을 때, dParallel은 GSM8K에서 디코딩 단계를 256에서 30으로 줄여 8.5배의 속도 향상을 달성했으며 성능 저하 없이 이를 유지했습니다. MBPP 벤치마크에서는 디코딩 단계를 256에서 24로 줄여 10.5배의 속도 향상을 이루었고 정확도를 유지했습니다. 우리의 코드는 https://github.com/czg1225/dParallel에서 확인할 수 있습니다.
English
Diffusion large language models (dLLMs) have recently drawn considerable attention within the research community as a promising alternative to autoregressive generation, offering parallel token prediction and lower inference latency. Yet, their parallel decoding potential remains largely underexplored, as existing open-source models still require nearly token-length decoding steps to ensure performance. To address this, we introduce dParallel, a simple and effective method that unlocks the inherent parallelism of dLLMs for fast sampling. We identify that the key bottleneck to parallel decoding arises from the sequential certainty convergence for masked tokens. Building on this insight, we introduce the core of our approach: certainty-forcing distillation, a novel training strategy that distills the model to follow its original sampling trajectories while enforcing it to achieve high certainty on masked tokens more rapidly and in parallel. Extensive experiments across various benchmarks demonstrate that our method can dramatically reduce the number of decoding steps while maintaining performance. When applied to the LLaDA-8B-Instruct model, dParallel reduces decoding steps from 256 to 30 on GSM8K, achieving an 8.5x speedup without performance degradation. On the MBPP benchmark, it cuts decoding steps from 256 to 24, resulting in a 10.5x speedup while maintaining accuracy. Our code is available at https://github.com/czg1225/dParallel
PDF121October 1, 2025