Scalable Diffusion Models with Transformers
摘要
基于transformer结构构建了一类新的扩散模型,DiT。即使增加transformer的深度/宽度或者增加输入tokens的数量,依旧能够保持更低的FID。最大的DiT-XL/2模型在ImageNet 512×512和256×256基准上性能优于所有先前的扩散模型,在后者上实现了2.27的FID。
方法
预备知识
1.扩散模型
前向加噪过程:对真实数据逐步添加噪声:
对扩散模型进行训练,学习正向加噪过程的反向过程:
反向过程模型通过 $x_{0}$ 对数似然的变分下界进行训练:
模型训练的损失函数为:
用 $\mathcal{L}_{simple}$ 训练 $\epsilon_{\theta}$ ,用完整的 $\mathcal{L}$ 训练 $\sum_{\theta}$ 。只要 $p_{\theta}$ 训练确定了,新的图像就能够通过初始化 $x_{t_{\mathrm{max}}}\sim\mathcal{N}(0,\mathbf{I})$ 采样得到,并采样 $x_{t-1}\sim p_\theta(x_{t-1}|x_t)$ 。
2.无分类器引导
条件扩散模型采用额外的信息作为输入,例如一个分类标签 $c$ 。在这种情况下,反向过程变成了 $p_\theta(x_{t-1} | x_t, c)$ 。在这种设定下,无分类引导能够用来鼓励采样过程找到一个 $x$ 使得 $\log p(c|x)$ 很大。根据贝叶斯规则,$\log p(c|x)\propto\log p(x|c)-\log p(x)$ ,因此 $\nabla_x\log p(c|x)\propto\nabla_x\log p(x|c)-\nabla_x\log p(x)$ 。通过将扩散模型的输出解释为得分函数,可以通过以下方式引导DDPM采样过程对具有高 $p(x|c)$ 的 $x$ 进行采样:$\hat{\epsilon}_\theta(x_t,c)=\epsilon_\theta(x_t,\emptyset)+s\cdot \nabla_{x}\log p(x|c)\propto\epsilon_{\theta}(x_{t},\emptyset)+s\cdot(\epsilon_{\theta}(x_{t},c)-\epsilon_{\theta}(x_{t},\emptyset))$ 。在训练的时候随机丢弃 $c$ 并替换为可学习的 “null” embedding $\emptyset$ 。
3.潜在扩散模型
在高分辨率像素空间中直接训练扩散模型计算成本很高。潜在扩散模型训练了一个autoencoder将图像压缩到更小的空间表示中,然后在潜在表示上训练扩散模型。通过生成潜在表示,可以采用解码器恢复出原图像。
Diffusion Transformer 设计空间
1.Patchify
DiT的输入是一个空间表示 $z$ 。DiT的第一层是patchify,将空间输入转换为 $T$ 个tokens的序列,每个token的维度为 $d$ 。每一个patch线性嵌入到输入中。然后对所有的输入tokens应用标准ViT的基于频率的位置嵌入。tokens的数量由patch size的超参数 $p$ 决定。
2.DiT block设计
输入tokens经过transformer blocks序列处理。除了添加噪声的图像输入,diffusion models也有额外的条件输入,如噪声时间步 $t$ ,分类标签 $c$ 及自然语言等。因此设计了四种transformer blocks的变体来处理不同的条件输入。
In-context 条件
将 $t$ 和 $c$ 的向量嵌入添加在输入序列中。采用标准的ViT blocks。在最后一个block一处条件tokens。
Cross-attention block
将 $t$ 和 $c$ 的嵌入连接为长度为2的序列,与图像token序列分开。transformer block在multi-head self-attention block后添加了额外的multi-head cross-attention层。
Adaptive layer norm(adaLN)block
将标准的layer norm层替换为adaLN。不直接学习逐维度的尺度和偏移参数 $ \gamma $ 和 $\beta$ ,而是从 $t$ 和 $c$ 的嵌入向量之和中进行回归
adaLN-Zero block
除了对 $\gamma$ 和 $\beta$ 进行回归外,还对DiT block内任何残差连接之前应用的维度缩放参数 $\alpha$ 进行回归。初始化MLP以输出所有 $\alpha$ 的零向量;这将整个DiT block初始化为identity function。
3.各种不同规格的DiT模型
4.Transformer decoder
再最后一个DiT block后面,需要将图像tokens解码输出维噪声预测以及对角协方差预测。应用最后一个layer norm和线性decoder将每一个token解码为$p\times p\times 2C$ 的张量。最终将解码的tokends重新恢复成原来的尺寸。