T-Stitch: Accelerating Sampling in Pre-Trained Diffusion Models with Trajectory Stitching
February 21, 2024
Authors: Zizheng Pan, Bohan Zhuang, De-An Huang, Weili Nie, Zhiding Yu, Chaowei Xiao, Jianfei Cai, Anima Anandkumar
cs.AI
Abstract
Sampling from diffusion probabilistic models (DPMs) is often expensive for
high-quality image generation and typically requires many steps with a large
model. In this paper, we introduce sampling Trajectory Stitching (T-Stitch), a
simple yet efficient technique to improve the sampling efficiency with little
or no generation degradation. Instead of solely using a large DPM for the
entire sampling trajectory, T-Stitch first leverages a smaller DPM in the
initial steps as a cheap drop-in replacement for the larger DPM and switches to
the larger DPM at a later stage. Our key insight is that different diffusion
models learn similar encodings under the same training data distribution and
smaller models are capable of generating good global structures in the early
steps. Extensive experiments demonstrate that T-Stitch is training-free,
generally applicable to different architectures, and complements most existing
fast sampling techniques with flexible speed and quality trade-offs. On DiT-XL,
for example, 40% of the early timesteps can be safely replaced with a 10x
faster DiT-S without performance drop on class-conditional ImageNet generation.
We further show that our method can also be used as a drop-in technique not
only to accelerate the popular pretrained Stable Diffusion (SD) models but also
to improve the prompt alignment of stylized SD models from the public model zoo.
Code is released at https://github.com/NVlabs/T-Stitch
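The stitching idea described above can be pictured as a sampling loop that hands the first fraction of the timesteps to a small model and the rest to a large one. The following is a minimal sketch under stated assumptions, not the implementation in the T-Stitch repository. The names `stitch_schedule`, `t_stitch_sample`, `relative_cost`, `small_dpm`, and `large_dpm` are hypothetical, and each toy "denoiser" stands in for a full model call plus solver update. The cost helper assumes that a small-model step costs `1 / small_speedup` of a large-model step and ignores switching overhead.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical denoiser signature: one sampling step mapping (x_t, t) -> x_{t-1}.
# Real DPM samplers predict noise and apply a solver update (e.g. DDIM); here
# the model call *is* the step, to keep the sketch self-contained.
Denoiser = Callable[[List[float], int], List[float]]


def stitch_schedule(num_steps: int, small_fraction: float) -> List[str]:
    """Assign the first `small_fraction` of the steps to the small model and
    the remaining steps to the large model."""
    if not 0.0 <= small_fraction <= 1.0:
        raise ValueError("small_fraction must be in [0, 1]")
    n_small = round(num_steps * small_fraction)
    return ["small"] * n_small + ["large"] * (num_steps - n_small)


def t_stitch_sample(
    x: List[float],
    timesteps: Sequence[int],
    small: Denoiser,
    large: Denoiser,
    small_fraction: float,
) -> Tuple[List[float], List[str]]:
    """Run one trajectory, switching from `small` to `large` partway through.

    Assumes both models share the same noise schedule and input space, so the
    intermediate x produced by the small model can be handed to the large one.
    """
    schedule = stitch_schedule(len(timesteps), small_fraction)
    for t, which in zip(timesteps, schedule):
        model = small if which == "small" else large
        x = model(x, t)
    return x, schedule


def relative_cost(small_fraction: float, small_speedup: float) -> float:
    """Cost of a stitched trajectory relative to using the large model for
    every step, assuming a small step costs 1 / small_speedup of a large step
    and ignoring any switching overhead."""
    return (1.0 - small_fraction) + small_fraction / small_speedup


# Toy stand-ins (hypothetical): each "denoiser" just shrinks x toward 0.
def small_dpm(x: List[float], t: int) -> List[float]:
    return [0.9 * v for v in x]


def large_dpm(x: List[float], t: int) -> List[float]:
    return [0.8 * v for v in x]


timesteps = list(range(9, -1, -1))  # 10 steps, from high noise (early) to low noise (late)
sample, schedule = t_stitch_sample([1.0, -2.0], timesteps, small_dpm, large_dpm, 0.4)
print(schedule)
print(f"relative cost: {relative_cost(0.4, 10.0):.2f}")  # relative cost: 0.64
```

With `small_fraction=0.4` over 10 steps, the schedule runs the small model for the four earliest (noisiest) steps and the large model for the remaining six. If the small model is assumed to be 10x faster, the stitched trajectory costs 0.6 + 0.4 / 10 = 0.64 of the all-large trajectory, or roughly 1.56x less model-evaluation time under that assumption. Sweeping `small_fraction` between 0 and 1 is one way to trace out a speed and quality trade-off.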