Understanding Hallucinations in Diffusion Models through Mode Interpolation
June 13, 2024
Authors: Sumukh K Aithal, Pratyush Maini, Zachary C. Lipton, J. Zico Kolter
cs.AI
Abstract
Colloquially speaking, image generation models based upon diffusion processes
are frequently said to exhibit "hallucinations," samples that could never occur
in the training data. But where do such hallucinations come from? In this
paper, we study a particular failure mode in diffusion models, which we term
mode interpolation. Specifically, we find that diffusion models smoothly
"interpolate" between nearby data modes in the training set, to generate
samples that are completely outside the support of the original training
distribution; this phenomenon leads diffusion models to generate artifacts that
never existed in real data (i.e., hallucinations). We systematically study the
reasons for, and the manifestations of, this phenomenon. Through experiments on
1D and 2D Gaussians, we show how a discontinuous loss landscape in the
diffusion model's decoder leads to a region where any smooth approximation will
cause such hallucinations. Through experiments on artificial datasets with
various shapes, we show how hallucination leads to the generation of
combinations of shapes that never existed. Finally, we show that diffusion
models in fact know when they go out of support and hallucinate. This is
captured by the high variance in the trajectory of the generated sample during
the final few steps of the backward sampling process. Using a simple metric to capture this
variance, we can remove over 95% of hallucinations at generation time while
retaining 96% of in-support samples. We conclude our exploration by showing the
implications of such hallucination (and its removal) on the collapse (and
stabilization) of recursive training on synthetic data with experiments on
MNIST and 2D Gaussian datasets. We release our code at
https://github.com/locuslab/diffusion-model-hallucination.
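
The detection idea described in the abstract can be illustrated with a short sketch: record the sampler's predicted clean sample over the last few reverse-diffusion steps and reject generations whose prediction is still fluctuating strongly. The code below is a minimal, hypothetical illustration of that filter, not the authors' released implementation (see the repository above for that); the function names `trajectory_variance` and `filter_hallucinations`, the window size, the threshold, and the toy trajectories are all illustrative assumptions.

```python
# Minimal sketch (NumPy only) of a trajectory-variance filter for reverse
# diffusion sampling. Assumes you have stored the sampler's predicted clean
# sample (x0_hat) at each reverse step; window/threshold values are placeholders.
import numpy as np

def trajectory_variance(x0_hat_trajectory: np.ndarray, window: int = 10) -> float:
    """Variance of the predicted clean sample over the last `window` reverse steps.

    x0_hat_trajectory: array of shape (num_steps, *sample_shape), ordered from
    the first reverse step (t = T) down to the last (t = 0).
    """
    tail = x0_hat_trajectory[-window:]          # final few backward steps
    # Per-dimension variance across the window, summed over sample dimensions.
    return float(np.var(tail, axis=0).sum())

def filter_hallucinations(samples, trajectories, threshold: float):
    """Keep only samples whose final-step trajectory variance is below the threshold."""
    kept = []
    for sample, traj in zip(samples, trajectories):
        if trajectory_variance(np.asarray(traj)) < threshold:
            kept.append(sample)
    return kept

# Toy usage: an in-support sample settles early (low variance), while a
# hallucinated one keeps drifting between modes (high variance).
stable = np.concatenate([np.linspace(1.0, 0.0, 40), np.zeros(10)])[:, None]
drifting = np.concatenate([np.linspace(1.0, 0.0, 40), np.linspace(0.0, 0.8, 10)])[:, None]
print(trajectory_variance(stable), trajectory_variance(drifting))
```

In this toy setting, thresholding the printed variance separates the two trajectories; in practice the threshold would be calibrated on held-out generations so that most in-support samples are retained.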