写点什么

深度学习进阶篇 [8]:对抗神经网络 GAN 基本概念简介、纳什均衡、生成器判别器、解码编码器详解以及 GAN 应用场景

  • 2023-06-01
    浙江
  • 本文字数:4312 字

    阅读完需:约 14 分钟

深度学习进阶篇[8]:对抗神经网络GAN基本概念简介、纳什均衡、生成器判别器、解码编码器详解以及GAN应用场景

深度学习进阶篇[8]:对抗神经网络 GAN 基本概念简介、纳什均衡、生成器判别器、解码编码器详解以及 GAN 应用场景

对抗神经网络 GAN 基本概念简介:generative adversarial network

1.博弈论

博弈论可以被认为是两个或多个理性的代理人或玩家之间相互作用的模型。


理性这个关键字,因为它是博弈论的基础。我们可以简单地把理性称为一种理解,即每个行为人都知道所有其他行为人都和他/她一样理性,拥有相同的理解和知识水平。同时,理性指的是,考虑到其他行为人的行为,行为人总是倾向于更高的报酬/回报。


既然我们已经知道了理性意味着什么,让我们来看看与博弈论相关的其他一些关键词:


  • 游戏:一般来说,游戏是由一组玩家,行动/策略和最终收益组成。例如:拍卖、象棋、政治等。

  • 玩家:玩家是参与任何游戏的理性实体。例如:在拍卖会的投标人、石头剪刀布的玩家、参加选举的政治家等。

  • 收益:收益是所有玩家在获得特定结果时所获得的奖励。它可以是正的,也可以是负的。正如我们之前所讨论的,每个代理都是自私的,并且想要最大化他们的收益。

2.纳什均衡

纳什均衡(或者纳什平衡),Nash equilibrium ,又称为非合作博弈均衡,是人工智能博弈论方法的“基石”。


所谓纳什均衡,指的是参与者的一种策略组合,在该策略上,任何参与人单独改变策略都不会得到好处,即每个人的策略都是对其他人的策略的最优反应。换句话说,如果在一个策略组合上,当所有其他人都不改变策略时,没有人会改变自己的策略,则该策略组合就是一个纳什均衡。


经典的例子就是囚徒困境


背景:一个案子的两个嫌疑犯 A 和 B 被警官分开审讯,所以 A 和 B 没有机会进行串供的;


奖惩:警官分别告诉 A 和 B,如果都不招供,则各判 3 年;如果两人均招供,均判 5 年;如果你招供、而对方不招供,则你判 1 年,对方 10 年。


结果:A 和 B 都选择招供,各判 5 年,这个便是此时的纳什均衡。


从奖惩说明看都不招供才是最优解,判刑最少。其实并不是这样,A 和 B 无法沟通,于是从各自的利益角度出发:


嫌疑犯 A 想法:


  • 如果 B 招供,如果我招供只判 5 年,不招供的话就判 10 年;

  • 如果 B 不招供,如果我招供只判 1 年,不招供的话就判 3 年;


所以无论 B 是否招供,A 只要招供了,对 A 而言是最优的策略。


同上,嫌疑犯 B 想法也是相同的,都依据各自的理性而选择招供,这种情况就被称为纳什均衡点。

3.GAN 生成器的输入为什么是噪声

GAN 生成器 Generator 的输入是随机噪声,目的是每次生成不同的图片。但如果完全随机,就不知道生成的图像有什么特征,结果就会不可控,因此通常从一个先验的随机分布产生噪声。常用的随机分布:


  • 高斯分布:连续变量中最广泛使用的概率分布;

  • 均匀分布:连续变量 x 的一种简单分布。


引入随机噪声使得生成的图片具有多样性,比如下图不同的噪声 z 可以产生不同的数字:


4.生成器 Generator

生成器 G 是一个生成图片的网络,可以采用多层感知机、卷积网络、自编码器等。它接收一个随机的噪声 z,通过这个噪声生成图片,记做 G(z)。通过下图模型结构讲解生成器如何一步步将噪声生成一张图片:



1)输入:100 维的向量;


2)经过两个全连接层 Fc1 和 Fc2、一个 Resize,将噪声向量放大,得到 128 个 7*7 大小的特征图;


3)进行上采样,以扩大特征图,得到 128 个 14*14 大小的特征图;


4)经过第一个卷积 Conv1,得到 64 个 14*14 的特征图;


5)进行上采样,以扩大特征图,得到 64 个 28*28 大小的特征图;


6)经过第二个卷积 Conv2,将输入的噪声 Z 逐渐转化为 12828 的单通道图片输出,得到生成的手写数字。


Tips:全连接层作用:维度变换,变为高维,方便将噪声向量放大。因为全连接层计算量稍大,后序改进的 GAN 移除全连接层。


Tips:最后一层激活函数通常使用 tanh():既起到激活作用,又起到归一作用,将生成器的输出归一化至[-1,1],作为判别器的输入。也使 GAN 的训练更稳定,收敛速度更快,生成质量确实更高。

5.判别器 Discriminator

判别器 D 的输入为真实图像和生成器生成的图像,其目的是将生成的图像从真实图像中尽可能的分辨出来。属于二分类问题,通过下图模型结构讲解判别器如何区分真假图片:


  • 输入:单通道图像,尺寸为 28*28 像素(非固定值,根据实际情况修改即可)。

  • 输出:二分类,样本是真或假。



1)输入:28281 像素的图像;


2)经过第一个卷积 conv1,得到 64 个 2626 的特征图,然后进行最大池化 pool1,得到 64 个 1313 的特征图;


3)经过第二个卷积 conv2,得到 128 个 1111 的特征图,然后进行最大池化 pool2,得到 128 个 55 的特征图;


4)通过 Resize 将多维输入一维化;


5)再经过两个全连接层 fc1 和 fc2,得到原始图像的向量表达;


6)最后通过 Sigmoid 激活函数,输出判别概率,即图片是真是假的二分类结果。

6.GAN 损失函数

在训练过程中,生成器 G(Generator)的目标就是尽量生成真实的图片去欺骗判别器 D(Discriminator)。而 D 的目标就是尽量把 G 生成的图片和真实的图片区分开。这样,G 和 D 构成了一个动态的“博弈过程”。


最后博弈的结果是什么?在最理想的状态下,G 可以生成足以“以假乱真”的图片 G(z)。对于 D 来说,它难以判定 G 生成的图片究竟是不是真实的,因此 D(G(z)) = 0.5。


用公式表示如下:


\begin{equation} \mathop{min}\limits_{G}\mathop{max}\limits_{D}V(D,G) = Ε_{x\sim p_{data}(x)} \left[\log D\left(x\right)\right]+Ε_{z\sim p_{z}(z)}\left[\log \left(1 - D\left(G\left(z\right)\right)\right)\right]\end{equation} \tag{1}


公式左边 V(D,G)表示生成图像和真实图像的差异度,采用二分类(真、假两个类别)的交叉熵损失函数。包含 minG 和 maxD 两部分:


表示固定生成器 G 训练判别器 D,通过最大化交叉熵损失 V(D,G)来更新判别器 D 的参数。D 的训练目标是正确区分真实图片 x 和生成图片 G(z),D 的鉴别能力越强,D(x)应该越大,右边第一项更大,D(G(x))应该越小,右边第二项更大。这时 V(D,G)会变大,因此式子对于 D 来说是求最大(maxD)。


表示固定判别器 D 训练生成器 G,生成器要在判别器最大化真、假图片交叉熵损失 V(D,G)的情况下,最小化这个交叉熵损失。此时右边只有第二项有用, G 希望自己生成的图片“越接近真实越好”,能够欺骗判别器,即 D(G(z))尽可能得大,这时 V(D, G)会变小。因此式子对于 G 来说是求最小(min_G)。


  • $$:表示真实图像;

  • :表示高斯分布的样本,即噪声;

  • D(x)代表 x 为真实图片的概率,如果为 1,就代表 100%是真实的图片,而输出为 0,就代表不可能是真实的图片。


等式的右边其实就是将等式左边的交叉商损失公式展开,并写成概率分布的期望形式。详细的推导请参见原论文Generative Adversarial Nets

7.模型训练

GAN 包含生成器 G 和判别器 D 两个网络,那么我们如何训练两个网络?



训练时先训练鉴别器 D 将真实图片打上真标签 1 和生成器 G 生成的假图片打上假标签 0,一同组成 batch 送入判别器 D,对判别器进行训练。计算 loss 时使判别器对真实图像输入的判别趋近于真,对生成的假图片的判别趋近于假。此过程中只更新判别器的参数,不更新生成器的参数。


然后再训练生成器 G 将高斯分布的噪声 z 送入生成器 G,将生成的假图片打上真标签 1 送入判别器 D。计算 loss 时使判别器对生成的假图片的判别趋近于真。此过程中只更新生成器的参数,不更新判别器的参数。


注意:训练初期,当 G 的生成效果很差时,D 会以高置信度来拒绝生成样本,因为它们与训练数据明显不同。因此,log(1−D(G(z)))饱和(即为常数,梯度为 0)。因此我们选择最大化 logD(G(z))而不是最小化 log(1−D(G(z)))来训练 G,和公示(1)右边第二项比较。

8 模型训练不稳定

GAN 训练不稳定的原因如下:


  • 不收敛:很难使两个模型 G 和 D 同时收敛;

  • 模式崩溃:生成器 G 生成单个或有限模式;

  • 慢速训练:生成器 G 的梯度消失。


训练 GAN 的时候,可以采取以下训练技巧:


1)生成器最后一层的激活函数用 tanh(),输出归一化至[-1, 1];


2)真实图像也归一化到[-1,1]之间;


3)学习率不要设置太大,初始 1e-4 可以参考,另外可以随着训练进行不断缩小学习率;


4)优化器尽量选择 Adam,因为 SGD 解决的是一个寻找最小值的问题,GAN 是一个博弈问题,使用 SGD 容易震荡;


5)避免使用 ReLU 和 MaxPool,减少稀疏梯度的可能性,可以使用 Leak Re LU 激活函数,下采样可以用 Average Pooling 或者 Convolution + stride 替代。上采样可以用 PixelShuffle, ConvTranspose2d + stride;


6)加噪声:在真实图像和生成图像中添加噪声,增加鉴别器训练难度,有利于提升稳定性;


7)如果有标签数据,尽量使用标签信息来训练;


8)标签平滑:如果真实图像的标签设置为 1,我们将它更改为一个较低的值,比如 0.9,避免鉴别器对其分类过于自信 。

9.编码器 Encoder

Encoder 目标是将输入序列编码成低维的向量表示或 embedding,映射函数如下:


\begin{equation}V\to R^{d}\end{equation} \tag{1}


即将输入 V 映射成 embedding ,如下图所示:



Encoder 一般是卷积神经网络,主要由卷积层,池化层和 BatchNormalization 层组成。卷积层负责获取图像局域特征,池化层对图像进行下采样并且将尺度不变特征传送到下一层,而 BN 主要对训练图像的分布归一化,加速学习。(Encoder 网络结构不局限于卷积神经网络)


以人脸编码为例,Encoder 将人脸图像压缩到短向量,这样短向量就包含了人脸图像的主要信息,例如该向量的元素可能表示人脸肤色、眉毛位置、眼睛大小等等。编码器学习不同人脸,那么它就能学习到人脸的共性:


10.解码器 Decoder

Decoder 目标是利用 Encoder 输出的 embedding,来解码关于图的结构信息。



输入是 Node Pair 的 embeddings,输出是一个实数,衡量了这两个 Node 在中的相似性,映射关系如下:


\begin{equation}R^{d} * R^{d}\to R^{+}\end{equation}. \tag{1}


Decoder 对缩小后的特征图像向量进行上采样,然后对上采样后的图像进行卷积处理,目的是完善物体的几何形状,弥补 Encoder 当中池化层将物体缩小造成的细节损失。


以人脸编码、解码为例,Encoder 对人脸进行编码之后,再用解码器 Decoder 学习人脸的特性,即由短向量恢复到人脸图像,如下图所示:


11.GAN 应用

一起来看看 GAN 有哪些有趣的应用:


  • 图像生成

  • 图像生成是生成模型的基本问题,GAN 相对先前的生成模型能够生成更高图像质量的图像。如生成逼真的人脸图像


  • 超分辨率

  • 将图像放大时,图片会变得模糊。使用 GAN 将 32*32 的图像扩展为 64*64 的真实图像,放大图像的同时提升图片的分辨率。



  • 图像修复

  • 将残缺的图像补全、也可以用于去除纹身、电视 logo、水印等。



  • 图像到图像的转换

  • 根据一幅图像生成生成另一幅风格不同图像,比如马变成斑马图、航拍地图变成地图


  • 风景动漫化

  • 将风景图转化为动漫效果


  • 漫画脸

  • 将人脸图生成卡通风格



  • 图像上色

  • 黑白影像上色


  • 文本转图像

  • 根据文字描述生成对应图像



GAN 的应用常用非常广泛,远远不止上述几种。


发布于: 刚刚阅读数: 4
用户头像

本博客将不定期更新关于NLP等领域相关知识 2022-01-06 加入

本博客将不定期更新关于机器学习、强化学习、数据挖掘以及NLP等领域相关知识,以及分享自己学习到的知识技能,感谢大家关注!

评论

发布
暂无评论
深度学习进阶篇[8]:对抗神经网络GAN基本概念简介、纳什均衡、生成器判别器、解码编码器详解以及GAN应用场景_人工智能_汀丶人工智能_InfoQ写作社区