추론 시간 스케일링을 통한 이산 확산 모델의 재마스킹
Remasking Discrete Diffusion Models with Inference-Time Scaling
March 1, 2025
저자: Guanghan Wang, Yair Schiff, Subham Sekhar Sahoo, Volodymyr Kuleshov
cs.AI
초록
확산 모델의 성공 요인 중 하나는 생성 과정에서 반복적으로 출력을 수정하는 능력, 즉 반복적 정제를 수행할 수 있는 데 있습니다. 그러나 현대의 마스킹된 이산 확산 모델은 이러한 능력이 부족합니다: 토큰이 생성되면, 오류가 발생하더라도 이를 다시 업데이트할 수 없습니다. 본 연구에서는 이러한 한계를 해결하기 위해 리마스킹 확산 모델(ReMDM) 샘플러를 소개합니다. 이 방법은 사전 학습된 마스킹 확산 모델에 원칙적으로 적용할 수 있으며, 사용자 정의 리마스킹 역과정을 가진 이산 확산 모델에서 유도되었습니다. 특히 흥미로운 점은, ReMDM이 이산 확산 모델에 추론 시 계산 규모 조정 기능을 부여한다는 것입니다. 샘플링 단계 수를 증가시킴으로써 ReMDM은 자기회귀 모델의 품질에 근접하는 자연어 출력을 생성할 수 있으며, 계산 예산이 제한된 경우에도 품질을 더 잘 유지합니다. ReMDM은 또한 이산화된 이미지에 대한 마스킹 확산 모델의 샘플 품질을 개선하고, 분자 설계와 같은 과학적 영역에서 확산 가이던스를 용이하게 하며, 기존의 마스킹 및 균일 잡음 확산에 비해 제어 가능성의 파레토 프론티어를 확장합니다. 프로젝트 페이지(https://remdm.github.io)에서 코드와 블로그 포스트를 제공합니다.
English
Part of the success of diffusion models stems from their ability to perform
iterative refinement, i.e., repeatedly correcting outputs during generation.
However, modern masked discrete diffusion lacks this capability: when a token
is generated, it cannot be updated again, even when it introduces an error.
Here, we address this limitation by introducing the remasking diffusion model
(ReMDM) sampler, a method that can be applied to pretrained masked diffusion
models in a principled way and that is derived from a discrete diffusion model
with a custom remasking backward process. Most interestingly, ReMDM endows
discrete diffusion with a form of inference-time compute scaling. By increasing
the number of sampling steps, ReMDM generates natural language outputs that
approach the quality of autoregressive models, whereas when the computation
budget is limited, ReMDM better maintains quality. ReMDM also improves sample
quality of masked diffusion models for discretized images, and in scientific
domains such as molecule design, ReMDM facilitates diffusion guidance and
pushes the Pareto frontier of controllability relative to classical masking and
uniform noise diffusion. We provide the code along with a blog post on the
project page: https://remdm.github.io.