Residual Context Diffusion Language Models
January 30, 2026
Authors: Yuezhou Hu, Harman Singh, Monishwaran Maheswaran, Haocheng Xi, Coleman Hooper, Jintao Zhang, Aditya Tomar, Michael W. Mahoney, Sewon Min, Mehrdad Farajtabar, Kurt Keutzer, Amir Gholami, Chenfeng Xu
cs.AI
Abstract
Diffusion Large Language Models (dLLMs) have emerged as a promising alternative to purely autoregressive language models because they can decode multiple tokens in parallel. However, state-of-the-art block-wise dLLMs rely on a "remasking" mechanism that decodes only the most confident tokens and discards the rest, effectively wasting computation. We demonstrate that recycling the computation spent on discarded tokens is beneficial, as these tokens retain contextual information useful for subsequent decoding iterations. In light of this, we propose Residual Context Diffusion (RCD), a module that converts the discarded token representations into contextual residuals and injects them back into the next denoising step. RCD uses a decoupled two-stage training pipeline to bypass the memory bottlenecks associated with backpropagation. We validate our method on both long-CoT reasoning (SDAR) and short-CoT instruction-following (LLaDA) models, and show that a standard dLLM can be efficiently converted to the RCD paradigm with only ~1 billion tokens. Across a wide range of benchmarks, RCD consistently improves the accuracy of frontier dLLMs by 5-10 points with minimal extra computational overhead. Notably, on the most challenging AIME tasks, RCD nearly doubles baseline accuracy and requires 4-5x fewer denoising steps at equivalent accuracy.
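To make the remasking-plus-recycling idea described above concrete, the following is a minimal, illustrative PyTorch sketch under our own assumptions; it is not the authors' implementation. It shows one block-wise denoising step that keeps only the most confident masked positions (remasking) and, instead of discarding the rest, projects their hidden states into a context residual injected into the next step. All names (ToyDenoiser, residual_proj, keep_ratio, MASK_ID) are hypothetical.

```python
# Illustrative sketch of remasking with residual-context recycling (not the paper's code).
import torch
import torch.nn as nn

MASK_ID = 0          # hypothetical id of the [MASK] token
VOCAB, HIDDEN = 32000, 1024

class ToyDenoiser(nn.Module):
    """Stand-in for a dLLM: maps a block of tokens to per-position logits."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, HIDDEN)
        self.backbone = nn.Linear(HIDDEN, HIDDEN)   # placeholder for the transformer stack
        self.lm_head = nn.Linear(HIDDEN, VOCAB)
        # Hypothetical RCD projection: turns discarded-token hidden states into residuals.
        self.residual_proj = nn.Linear(HIDDEN, HIDDEN)

    def step(self, tokens, context_residual=None, keep_ratio=0.25):
        h = self.embed(tokens)
        if context_residual is not None:
            # Inject the residual context recycled from the previous denoising step.
            h = h + context_residual
        h = torch.tanh(self.backbone(h))
        logits = self.lm_head(h)
        conf, pred = logits.softmax(-1).max(-1)

        masked = tokens.eq(MASK_ID)
        # Remasking: commit only the top-confidence masked positions, remask the rest.
        k = max(1, int(keep_ratio * masked.sum().item()))
        conf_masked = conf.masked_fill(~masked, float("-inf"))
        keep_idx = conf_masked.topk(k, dim=-1).indices
        new_tokens = tokens.clone()
        new_tokens.scatter_(-1, keep_idx, pred.gather(-1, keep_idx))

        # RCD idea: hidden states of still-masked positions are not thrown away but
        # projected into a contextual residual for the next denoising step.
        still_masked = new_tokens.eq(MASK_ID)
        residual = self.residual_proj(h) * still_masked.unsqueeze(-1).float()
        return new_tokens, residual

# Usage: denoise a fully masked block, threading the residual between steps.
model = ToyDenoiser()
tokens = torch.full((1, 16), MASK_ID, dtype=torch.long)
residual = None
for _ in range(16):                      # bounded number of denoising steps
    tokens, residual = model.step(tokens, residual)
    if not tokens.eq(MASK_ID).any():
        break
```

In this sketch the residual is only added at the embedding layer for simplicity; where exactly the residual is injected, and how the projection is trained in the decoupled two-stage pipeline, is specific to the paper and not reproduced here.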