一文详解扩散模型:DDPM
作者:京东零售 刘岩
扩散模型讲解
前沿
人工智能生成内容(AI Generated Content,AIGC)近年来成为了非常前沿的一个研究方向,生成模型目前有四个流派,分别是生成对抗网络(Generative Adversarial Models,GAN),变分自编码器(Variance Auto-Encoder,VAE),标准化流模型(Normalization Flow, NF)以及这里要介绍的扩散模型(Diffusion Models,DM)。扩散模型是受到热力学中的一个分支,它的思想来源是非平衡热力学(Non-equilibrium thermodynamics)。扩散模型的算法理论基础是通过变分推断(Variational Inference)训练参数化的马尔可夫链(Markov Chain),它在许多任务上展现了超过 GAN 等其它生成模型的效果,例如最近非常火热的 OpenAI 的 DALL-E 2,Stability.ai 的 Stable Diffusion 等。这些效果惊艳的模型扩散模型的理论基础便是我们这里要介绍的提出扩散模型的文章[1]和非常重要的 DDPM[2],扩散模型的实现并不复杂,但其背后的数学原理却非常丰富。在这里我会介绍这些重要的数学原理,但省去了这些公式的推导计算,如果你对这些推导感兴趣,可以学习参考文献[4,5,11]的相关内容。我在这里主要以一个相对简单的角度来讲解扩散模型,帮助你快速入门这个非常重要的生成算法。
1. 背景知识: 生成模型
目前生成模型主要有图 1 所示的四类。其中 GAN 的原理是通过判别器和生成器的互相博弈来让生成器生成足以以假乱真的图像。VAE 的原理是通过一个编码器将输入图像编码成特征向量,它用来学习高斯分布的均值和方差,而解码器则可以将特征向量转化为生成图像,它侧重于学习生成能力。流模型是从一个简单的分布开始,通过一系列可逆的转换函数将分布转化成目标分布。扩散模型先通过正向过程将噪声逐渐加入到数据中,然后通过反向过程预测每一步加入的噪声,通过将噪声去掉的方式逐渐还原得到无噪声的图像,扩散模型本质上是一个马尔可夫架构,只是其中训练过程用到了深度学习的 BP,但它更属于数学层面的创新。这也就是为什么很多计算机的同学看扩散模型相关的论文会如此费力。
图 1:生成模型的四种类型 [4]
扩散模型中最重要的思想根基是马尔可夫链,它的一个关键性质是平稳性。即如果一个概率随时间变化,那么再马尔可夫链的作用下,它会趋向于某种平稳分布,时间越长,分布越平稳。如图 2 所示,当你向一滴水中滴入一滴颜料时,无论你滴在什么位置,只要时间足够长,最终颜料都会均匀的分布在水溶液中。这也就是扩散模型的前向过程。
图 2:颜料分子在水溶液中的扩散过程
如果我们能够在扩散的过程颜料分子的位置、移动速度、方向等移动属性。那么也可以根据正向过程的保存的移动属性从一杯被溶解了颜料的水中反推颜料的滴入位置。这边是扩散模型的反向过程。记录移动属性的快照便是我们要训练的模型。
2. 扩散模型
在这一部分我们将集中介绍扩散模型的数学原理以及推导的几个重要性质,因为推导过程涉及大量的数学知识但是对理解扩散模型本身思想并无太大帮助,所以这里我会省去推导的过程而直接给出结论。但是我也会给出推导过程的出处,对其中的推导过程比较感兴趣的请自行查看。
2.1 计算原理
扩散模型简单的讲就是通过神经网络学习从纯噪声数据逐渐对数据进行去噪的过程,它包含两个步骤,如图 3:
图 3:DDPM 的前向加噪和后向去噪过程
2.1.1 前向过程
2.1.2 后向过程
2.1.3 目标函数
那么问题来了,我们究竟使用什么样的优化目标才能比较好的预测高斯噪声的分布呢?一个比较复杂的方式是使用变分自编码器的最大化证据下界(Evidence Lower Bound, ELBO)的思想来推导,如式(6),推导详细过程见论文[11]的式(47)到式(58),这里主要用到了贝叶斯定理和琴生不等式。
式(6)的推导细节并不重要,我们需要重点关注的是它的最终等式的三个组成部分,下面我们分别介绍它们:
图 4:扩散模型的去噪匹配项在每一步都要拟合噪音的真实后验分布和估计分布
真实后验分布可以使用贝叶斯定理进行推导,最终结果如式(8),推导过程见论文[11]的式(71)到式(84)。
2.1.4 模型训练
虽然上面我们介绍了很多内容,并给出了大量公式,但得益于推导出的几个重要性质,扩散模型的训练并不复杂,它的训练伪代码见算法 1。
2.1.5 样本生成
2.2 算法实现
2.2.1 模型结构
DDPM 在预测施加的噪声时,它的输入是施加噪声之后的图像,预测内容是和输入图像相同尺寸的噪声,所以它可以看做一个 Img2Img 的任务。DDPM 选择了 U-Net[9]作为噪声预测的模型结构。U-Net 是一个 U 形的网络结构,它由编码器,解码器以及编码器和解码器之间的跨层连接(残差连接)组成。其中编码器将图像降采样成一个特征,解码器将这个特征上采样为目标噪声,跨层连接用于拼接编码器和解码器之间的特征。
图 5:U-Net 的网络结构
下面我们介绍 DDPM 的模型结构的重要组件。首先在 U-Net 的卷积部分,DDPM 使用了宽残差网络(Wide Residual Network,WRN)[12]作为核心结构,WRN 是一个比标准残差网络层数更少,但是通道数更多的网络结构。也有作者复现发现 ConvNeXT 作为基础结构会取得非常显著的效果提升[13,14]。这里我们可以根据训练资源灵活的调整卷积结构以及具体的层数等超参。因为我们在扩散过程的整个流程中都共享同一套参数,为了区分不同的时间片,作者借鉴了 Transformer [15]的位置编码的思想,采用了正弦位置嵌入对时间进行了编码,这使得模型在预测噪声时知道它预测的是批次中分别是哪个时间片添加的噪声。在卷积层之间,DDPM 添加了一个注意力层。这里我们可以使用 Transformer 中提出的自注意力机制或是多头自注意力机制。[13]则提出了一个线性注意力机制的模块,它的特点是消耗的时间以及占用的内存和序列长度是线性相关的,对比传统注意力机制的平方相关要高效很多。在进行归一化时,DDPM 选择了组归一化(Group Normalization,GN)[16]。最后,对于 U-Net 中的降采样和上采样操作,DDPM 分别选择了步长为 2 的卷积以及反卷积。
确定了这些组件,我们便可以搭建用于 DDPM 的 U-Net 的模型了。从第 2.1 节的介绍我们知道,模型的输入为形状为(batch_size, num_channels, height, width)的噪声图像和形状为(batch_size,1)的噪声水平,返回的是形状为(batch_size, num_channels, height, width)的预测噪声,我们搭建的用于噪声预测的模型结构如下:
首先在噪声图像\( \boldsymbol x_0\)上应用卷积层,并为噪声水平计算时间嵌入;
接下来是降采样阶段。采用的模型结构依次是两个卷积(WRNS 或是 ConvNeXT)+GN+Attention+降采样层;
在网络的最中间,依次是卷积层+Attention+卷积层;
接下来是上采样阶段。它首先会使用 Short-cut 拼接来自降采样中同样尺寸的卷积,再之后是两个卷积+GN+Attention+上采样层。
最后是使用 WRNS 或是 ConvNeXT 作为输出层的卷积。
U-Net 类的 forword 函数如下面代码片段所示,完整的实现代码参照[3]。
2.2.2 前向加噪
图 6:一张图依次经过 0 次,50 次,100 次,150 次以及 199 次加噪后的效果图
根据式(14)我们知道,扩散模型的损失函数计算的是两张图像的相似性,因此我们可以选择使用回归算法的所有损失函数,以 MSE 为例,前向过程的核心代码如下面代码片段。
2.2.3 样本生成
根据 2.1.5 节介绍的样本生成流程,它的核心代码片段所示,关于这段代码的讲解我通过注释添加到了代码片段中。
最后我们看下在人脸图像数据集下训练的模型,一批随机噪声经过逐渐去噪变成人脸图像的示例。
图 7:扩散模型由随机噪声通过去噪逐渐生成人脸图像
3. 总结
这里我们以 DDPM 为例介绍了另一个派系的生成算法:扩散模型。扩散模型是一个基于马尔可夫链的数学模型,它通过预测每个时间片添加的噪声来进行模型的训练。作为近日来引发热烈讨论的 ControlNet, Stable Diffusion 等模型的底层算法,我们十分有必要对其有所了解。DDPM 的实现并不复杂,这得益于大量数学界大佬通过大量的数学推导将整个扩散过程和反向去噪过程进行了精彩的化简,这才有了 DDPM 的大道至简的实现。DDPM 作为一个扩散模型的基石算法,它有着很多早期算法的共同问题:
采样速度慢:DDPM 的去噪是从时刻到时刻的一个完整的马尔可夫链的计算,尤其是 DDPM 还需要一个比较大的才能保证比较好的效果,这就导致了 DDPM 的采样过程注定是非常慢的;
生成效果差:DDPM 的效果并不能说是非常好,尤其是对于高分辨率图像的生成。这一方面是因为它的计算速度限制了它扩展到更大的模型;另一方面它的设计还有一些问题,例如逐像素的计算损失并使用相同权值而忽略图像中的主体并不是非常好的策略。
内容不可控:我们可以看出,DDPM 生成的内容完全还是取决于它的训练集。它并没有引入一些先验条件,因此并不能通过控制图像中的细节来生成我们制定的内容。
我们现在已经知道,DDPM 的这些问题已大幅得到改善,现在基于扩散模型生成的图像已经达到甚至超过人类多数的画师的效果,我也会在之后逐渐给出这些优化方案的讲解。
Reference
[1] Sohl-Dickstein, Jascha, et al. "Deep unsupervised learning using nonequilibrium thermodynamics." International Conference on Machine Learning. PMLR, 2015.
[2] Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Advances in Neural Information Processing Systems 33 (2020): 6840-6851.
[3] https://huggingface.co/blog/annotated-diffusion
[4] https://lilianweng.github.io/posts/2021-07-11-diffusion-models/#simplification
[5] https://openai.com/blog/generative-models/
[6] Nichol, Alexander Quinn, and Prafulla Dhariwal. "Improved denoising diffusion probabilistic models." International Conference on Machine Learning. PMLR, 2021.
[7] Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes." arXiv preprint arXiv:1312.6114 (2013).
[8] Hinton, Geoffrey E., and Ruslan R. Salakhutdinov. "Reducing the dimensionality of data with neural networks." science 313.5786 (2006): 504-507.
[9] Ronneberger O, Fischer P, Brox T. U-net: Convolutional networks for biomedical image segmentation[C]//International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015: 234-241.
[10] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks for semantic segmentation." Proceedings of the IEEE conference on computer vision and pattern recognition. 2015.
[11] Luo, Calvin. "Understanding diffusion models: A unified perspective." arXiv preprint arXiv:2208.11970 (2022).
[12] Zagoruyko, Sergey, and Nikos Komodakis. "Wide residual networks." arXiv preprint arXiv:1605.07146 (2016).
[13] https://github.com/lucidrains/denoising-diffusion-pytorch
[14] Liu, Zhuang, et al. "A convnet for the 2020s." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
[15] Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems 30 (2017).
[16] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018.
版权声明: 本文为 InfoQ 作者【京东科技开发者】的原创文章。
原文链接:【http://xie.infoq.cn/article/a598f10f348393c61d898d0df】。文章转载请联系作者。
评论