DisCo-Diff: 離散潜在変数による連続拡散モデルの強化
DisCo-Diff: Enhancing Continuous Diffusion Models with Discrete Latents
July 3, 2024
著者: Yilun Xu, Gabriele Corso, Tommi Jaakkola, Arash Vahdat, Karsten Kreis
cs.AI
要旨
拡散モデル(DMs)は生成学習に革命をもたらしました。これらのモデルは、データを単純なガウス分布にエンコードするために拡散プロセスを利用します。しかし、複雑で潜在的に多峰性を持つデータ分布を単一の連続的なガウス分布にエンコードすることは、不必要に困難な学習問題を引き起こすと考えられます。本論文では、この課題を簡素化するために、補完的な離散潜在変数を導入したDiscrete-Continuous Latent Variable Diffusion Models(DisCo-Diff)を提案します。我々は、エンコーダによって推論される学習可能な離散潜在変数をDMsに追加し、DMとエンコーダをエンドツーエンドで学習します。DisCo-Diffは事前学習済みネットワークに依存しないため、フレームワークとして普遍的に適用可能です。離散潜在変数を導入することで、DMの複雑なノイズからデータへのマッピングを学習する際の曲率が低減され、学習が大幅に簡素化されます。さらに、オートリグレッシブトランスフォーマーを用いて離散潜在変数の分布をモデル化しますが、DisCo-Diffでは少数の離散変数と小さなコードブックしか必要としないため、このステップは簡単です。我々は、DisCo-Diffをトイデータ、いくつかの画像合成タスク、および分子ドッキングで検証し、離散潜在変数を導入することで一貫してモデルの性能が向上することを確認しました。例えば、DisCo-DiffはODEサンプラーを用いて、クラス条件付きImageNet-64/128データセットにおいて最先端のFIDスコアを達成しました。
English
Diffusion models (DMs) have revolutionized generative learning. They utilize
a diffusion process to encode data into a simple Gaussian distribution.
However, encoding a complex, potentially multimodal data distribution into a
single continuous Gaussian distribution arguably represents an unnecessarily
challenging learning problem. We propose Discrete-Continuous Latent Variable
Diffusion Models (DisCo-Diff) to simplify this task by introducing
complementary discrete latent variables. We augment DMs with learnable discrete
latents, inferred with an encoder, and train DM and encoder end-to-end.
DisCo-Diff does not rely on pre-trained networks, making the framework
universally applicable. The discrete latents significantly simplify learning
the DM's complex noise-to-data mapping by reducing the curvature of the DM's
generative ODE. An additional autoregressive transformer models the
distribution of the discrete latents, a simple step because DisCo-Diff requires
only few discrete variables with small codebooks. We validate DisCo-Diff on toy
data, several image synthesis tasks as well as molecular docking, and find that
introducing discrete latents consistently improves model performance. For
example, DisCo-Diff achieves state-of-the-art FID scores on class-conditioned
ImageNet-64/128 datasets with ODE sampler.Summary
AI-Generated Summary