PyTorch 深度学习实战 | Fashion MNIST 图片重建实战
01、操作流程
1)加载 Fashion MNIST 数据集
Fashion MNIST 是一个定位在比 MNIST 图片识别问题稍复杂的数据集,它的设定与 MNIST 几乎完全一样,包含了 10 类不同类型的衣服、鞋子、包等灰度图片,图片大小为,共 70000 张图片,其中 60000 张用于训练集,10000 张用于测试集,如图 1 所示,每行是一种类别图片。可以看到,Fashion MNIST 除了图片内容与 MNIST 不一样,其它设定都相同,大部分情况可以直接替换掉原来基于 MNIST 训练的算法代码,而不需要额外修改。由于 Fashion MNIST 图片识别相对于 MNIST 图片更难,因此可以用于测试稍复杂的算法性能。
▍图 1 Fashion MNIST 数据集
在 TensorFlow 中,加载 Fashion MNIST 数据集同样非常方便,利用 keras.datasets.fashion_mnist.load_data()函数即可在线下载、管理和加载。代码如下:
2)利用编码器降维
我们利用编码器将输入图片 x∈R784降维到较低维度的隐藏向量 h∈R20,并基于隐藏向量利用解码器重建图片,自编码器模型如图 2 所示,编码器由 3 层全连接层网络组成,输出节点数分别为 256、128、20,解码器同样由 3 层全连接网络组成,输出节点数分别为 128、256、784。
▍图 2 Fashion MNIST 自编码器网络结构
首先是编码器子网络的实现。利用 3 层的神经网络将长度为 784 的图片向量数据依次降维到 256、128,最后降维到 h_dim 维度,每层使用 ReLU 激活函数,最后一层不使用激活函数。代码如下:
3)创建解码器
然后再来创建解码器子网络,这里基于隐藏向量 h_dim 依次升维到 128、256、784 长度,除最后一层,激活函数使用 ReLU 函数。解码器的输出为 784 长度的向量,代表了打平后的大小图片,通过 Reshape 操作即可恢复为图片矩阵。代码如下:
4)自编码器
上述的编码器和解码器 2 个子网络均实现在自编码器类 AE 中,我们在初始化函数中同时创建这两个子网络。代码如下:
接下来将前向传播过程实现在 call 函数中,输入图片首先通过 encoder 子网络得到隐藏向量 h,再通过 decoder 得到重建图片。依次调用编码器和解码器的前向传播函数即可,代码如下:
5)网络训练
自编码器的训练过程与分类器的基本一致,通过误差函数计算出重建向量与原始输入向量之间的距离,再利用 TensorFlow 的自动求导机制同时求出 encoder 和 decoder 的梯度,循环更新即可。
首先创建自编码器实例和优化器,并设置合适的学习率。例如:
这里固定训练 100 个 Epoch,每次通过前向计算获得重建图片向量,并利用 tf.nn.sigmoid_cross_entropy_with_logits 损失函数计算重建图片与原始图片直接的误差,实际上利用 MSE 误差函数也是可行的。代码如下:
6)图片重建
与分类问题不同的是,自编码器的模型性能一般不好量化评价,尽管值可以在一定程度上代表网络的学习效果,但我们最终希望获得还原度较高、样式较丰富的重建样本。因此一般需要根据具体问题来讨论自编码器的学习效果,比如对于图片重建,一般依赖于人工主观评价图片生成的质量,或利用某些图片逼真度计算方法(如 Inception Score 和 Frechet Inception Distance)来辅助评估。
为了测试图片重建效果,我们把数据集切分为训练集与测试集,其中测试集不参与训练。我们从测试集中随机采样测试图片,经过自编码器计算得到重建后的图片,然后将真实图片与重建图片保存为图片阵列,并可视化,方便比对。代码如下:
02、运行结果
图片重建的效果如图 3~5 所示,其中每张图片的左边 5 列为真实图片,右边 5 列为对应的重建图片。可以看到,第一个 Epoch 时,图片重建效果较差,图片非常模糊,逼真度较差;随着训练的进行,重建图片边缘越来越清晰,第 100 个 Epoch 时,重建的图片效果已经比较接近真实图片。
▍图 3 第 1 个 Epoch
▍图 4 第 10 个 Epoch
▍图 5 第 100 个 Epoch
这里的 save_images 函数负责将多张图片合并并保存为一张大图,这部分代码使用 PIL 图片库完成图片阵列逻辑,代码如下:
3、完整代码
版权声明: 本文为 InfoQ 作者【TiAmo】的原创文章。
原文链接:【http://xie.infoq.cn/article/f06807349eb0bd120cec0368d】。文章转载请联系作者。
评论