簡略化・一般化された離散データ向けマスク拡散モデル
Simplified and Generalized Masked Diffusion for Discrete Data
June 6, 2024
著者: Jiaxin Shi, Kehang Han, Zhe Wang, Arnaud Doucet, Michalis K. Titsias
cs.AI
要旨
マスクド(または吸収型)拡散モデルは、離散データの生成モデリングにおける自己回帰モデルの代替として積極的に研究されています。しかし、この分野の既存研究は、不必要に複雑なモデル定式化や異なる視点間の関係性の不明瞭さに阻まれており、最適でないパラメータ設定、訓練目的関数、およびこれらの問題に対処するためのアドホックな調整が行われてきました。本研究では、マスクド拡散モデルの真の可能性を引き出すためのシンプルで汎用的なフレームワークを提供することを目指します。マスクド拡散モデルの連続時間変分目的関数が、クロスエントロピー損失の単純な重み付き積分であることを示します。また、本フレームワークにより、状態依存型マスキングスケジュールを用いた一般化されたマスクド拡散モデルの訓練が可能となります。OpenWebTextで訓練したモデルは、GPT-2規模の従来の拡散言語モデルをパープレキシティの点で上回り、5つのゼロショット言語モデリングタスクのうち4つで優れた性能を示しました。さらに、本モデルはピクセルレベルの画像モデリングにおいて従来の離散拡散モデルを大幅に上回り、CIFAR-10で2.78、ImageNet 64×64で3.42ビット/次元を達成し、同規模の自己回帰モデルと同等またはそれ以上の性能を示しました。
English
Masked (or absorbing) diffusion is actively explored as an alternative to
autoregressive models for generative modeling of discrete data. However,
existing work in this area has been hindered by unnecessarily complex model
formulations and unclear relationships between different perspectives, leading
to suboptimal parameterization, training objectives, and ad hoc adjustments to
counteract these issues. In this work, we aim to provide a simple and general
framework that unlocks the full potential of masked diffusion models. We show
that the continuous-time variational objective of masked diffusion models is a
simple weighted integral of cross-entropy losses. Our framework also enables
training generalized masked diffusion models with state-dependent masking
schedules. When evaluated by perplexity, our models trained on OpenWebText
surpass prior diffusion language models at GPT-2 scale and demonstrate superior
performance on 4 out of 5 zero-shot language modeling tasks. Furthermore, our
models vastly outperform previous discrete diffusion models on pixel-level
image modeling, achieving 2.78~(CIFAR-10) and 3.42 (ImageNet 64times64) bits
per dimension that are comparable or better than autoregressive models of
similar sizes.