Understanding Hallucinations in Diffusion Models through Mode Interpolation
June 13, 2024
Authors: Sumukh K Aithal, Pratyush Maini, Zachary C. Lipton, J. Zico Kolter
cs.AI
Abstract
Colloquially speaking, image generation models based upon diffusion processes
are frequently said to exhibit "hallucinations," samples that could never occur
in the training data. But where do such hallucinations come from? In this
paper, we study a particular failure mode in diffusion models, which we term
mode interpolation. Specifically, we find that diffusion models smoothly
"interpolate" between nearby data modes in the training set, to generate
samples that are completely outside the support of the original training
distribution; this phenomenon leads diffusion models to generate artifacts that
never existed in real data (i.e., hallucinations). We systematically study the
reasons for, and the manifestation of this phenomenon. Through experiments on
1D and 2D Gaussians, we show how a discontinuous loss landscape in the
diffusion model's decoder leads to a region where any smooth approximation will
cause such hallucinations. Through experiments on artificial datasets with
various shapes, we show how hallucination leads to the generation of
combinations of shapes that never existed. Finally, we show that diffusion
models in fact know when they go out of support and hallucinate. This is
captured by the high variance in the trajectory of the generated sample during
the final few steps of the backward sampling process. Using a simple metric to capture this
variance, we can remove over 95% of hallucinations at generation time while
retaining 96% of in-support samples. We conclude our exploration by showing the
implications of such hallucination (and its removal) on the collapse (and
stabilization) of recursive training on synthetic data with experiments on
MNIST and 2D Gaussian datasets. We release our code at
https://github.com/locuslab/diffusion-model-hallucination.
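
To make the filtering idea concrete, below is a minimal sketch of variance-based hallucination detection under the assumption of a standard DDPM ancestral sampler with an ε-prediction denoiser. This is not the authors' released implementation (see the repository linked above for that); the function name, the schedule interface (`alphas_cumprod`), the number of tracked steps, and the variance threshold are hypothetical choices for illustration.

```python
# Hypothetical sketch: drop generated samples whose predicted-x0 trajectory
# has high variance over the final few reverse-diffusion steps, the signal
# the paper uses as a proxy for out-of-support (hallucinated) samples.
import torch


@torch.no_grad()
def sample_and_filter(denoiser, alphas_cumprod, shape, num_steps=1000,
                      track_last_k=50, var_threshold=0.1, device="cpu"):
    """DDPM-style ancestral sampling that records the predicted x0 over the
    last `track_last_k` steps and keeps only samples whose trajectory
    variance stays below `var_threshold`.

    `alphas_cumprod` is assumed to be a 1-D tensor of cumulative products
    of (1 - beta_t), indexed by timestep t in [0, num_steps).
    """
    x = torch.randn(shape, device=device)
    history = []  # predicted x0 at each tracked final step

    for t in reversed(range(num_steps)):
        a_bar = alphas_cumprod[t]
        a_bar_prev = alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0, device=device)
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps = denoiser(x, t_batch)  # epsilon-prediction network

        # Predicted clean sample x0 from the current noisy x_t.
        x0_pred = (x - torch.sqrt(1 - a_bar) * eps) / torch.sqrt(a_bar)
        if t < track_last_k:
            history.append(x0_pred)

        # Standard DDPM posterior mean/variance step (no noise at t == 0).
        alpha_t = a_bar / a_bar_prev
        mean = (x - (1 - alpha_t) / torch.sqrt(1 - a_bar) * eps) / torch.sqrt(alpha_t)
        sigma_t = torch.sqrt((1 - a_bar_prev) / (1 - a_bar) * (1 - alpha_t))
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        x = mean + sigma_t * noise

    # Per-sample variance of predicted x0 across the tracked steps,
    # averaged over all data dimensions.
    traj = torch.stack(history)                         # (track_last_k, B, ...)
    traj_var = traj.var(dim=0).flatten(1).mean(dim=1)   # (B,)

    keep = traj_var < var_threshold
    return x[keep], traj_var
```

In this sketch, `var_threshold` plays the role of the paper's simple variance metric: samples whose final-step trajectories fluctuate strongly are flagged as likely hallucinations and discarded, while low-variance (in-support) samples are retained.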