ChatPaper.aiChatPaper

離散擴散中攤提序列蒙地卡羅的對比分布匹配

Contrastive Distribution Matching for Amortized Sequential Monte Carlo in Discrete Diffusion

May 22, 2026
作者: Jaihoon Kim, Taehoon Yoon, Prin Phunyaphibarn, Seungjun Kim, Morteza Mardani, Minhyuk Sung
cs.AI

摘要

離散擴散模型已成為生成結構化類別資料的強大框架。然而,如何有效從獎勵傾斜分佈中進行取樣仍是一項基本挑戰。儘管扭曲序列蒙特卡羅方法(Twisted Sequential Monte Carlo, SMC)能為此任務提供漸進精確的近似解,但在離散狀態空間中估計最優扭曲函數仍需耗費大量的蒙地卡羅近似計算,導致推論階段出現嚴重的計算瓶頸。為克服此限制,我們提出對比分佈匹配(Contrastive Distribution Matching, CDM)——一個透過正負樣本學習參數化扭曲函數,從而將SMC推論成本分攤化的新框架。為實現高效訓練,我們重新設計梯度估計器,使其能利用離散擴散模型中的封閉形式前向核函數。在實務中,評估我們所學得的扭曲函數僅會產生不到5%的額外計算開銷(相較於基礎模型的單次前向傳遞)。透過廣泛的實證評估,我們證實CDM在匹配的實際執行時間下持續優於現有基準方法。我們在多樣化的應用場景中驗證了本方法的有效性與通用性,包括有毒文本生成、調控性DNA序列設計、蛋白質可設計性,以及擴散大語言模型對齊。
English
Discrete diffusion models have emerged as powerful frameworks for generating structured categorical data. However, efficiently sampling from reward-tilted distributions remains a fundamental challenge. While Twisted Sequential Monte Carlo (SMC) offers asymptotic exactness for this task, estimating the optimal twist function in discrete state spaces necessitates costly Monte Carlo approximations, resulting a severe computational bottleneck at inference. To overcome this limitation, we introduce Contrastive Distribution Matching (CDM), a novel framework that amortizes the cost of SMC inference by learning a parameterized twist function via positive and negative samples. For efficient training, we reformulate the gradient estimator to leverage the closed-form forward kernels of discrete diffusion models. In practice, evaluating our learned twist function incurs less than 5% additional computational overhead compared to a single forward pass of the base model. Through extensive empirical evaluations, we demonstrate that CDM consistently outperforms existing baselines under matched wall-clock time. We validate the effectiveness and versatility of our approach across a diverse range of applications, including toxic text generation, regulatory DNA sequence design, protein designability, and diffusion large language model alignment.