写点什么

带你徒手完成基于 MindSpore 的 CycleGAN 实现

  • 2022 年 8 月 16 日
    中国香港
  • 本文字数:2836 字

    阅读完需:约 9 分钟

带你徒手完成基于MindSpore的CycleGAN实现

本文分享自华为云社区《基于MindSpore的CycleGAN介绍和实现》,作者: Tianyi_Li 。

前言


我们这次介绍下著名的 CycleGAN,同时提供了基于 MindSpore 的代码,方便大家运行验证。

CycleGAN 的介绍


CycleGAN 图像翻译模型,由两个生成网络和两个判别网络组成,通过非成对的图片将某一类图片转换成另外一类图片,可用于风格迁移,效果演示如下图所示:



CycleGAN 是 GAN 的一种,那什么是 GAN 呢?


生成对抗网络(Generative Adversarial Network, 简称 GAN) 是一种非监督学习的方式,通过让两个神经网络相互博弈的方法进行学习,该方法由 lan Goodfellow 等人在 2014 年提出。生成对抗网络由一个生成网络和一个判别网络组成,生成网络从潜在的空间(latent space)中随机采样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别网络的输入为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能的分辨出来。而生成网络则尽可能的欺骗判别网络,两个网络相互对抗,不断调整参数。生成对抗网络常用于生成以假乱真的图片。此外,该方法还被用于生成影片,三维物体模型等。


好了,我们已经对 GAN 有了大体的了解,下面说回 CycleGAN。


CycleGAN 由两个生成网络和两个判别网络组成,生成网络 A 是输入 A 类风格的图片输出 B 类风格的图片,生成网络 B 是输入 B 类风格的图片输出 A 类风格的图片。生成网络中编码部分的网络结构都是采用 convolution-norm-ReLU 作为基础结构,解码部分的网络结构由 transpose convolution-norm-ReLU 组成,判别网络基本是由 convolution-norm-leaky_ReLU 作为基础结构,详细的网络结构可以查看 network/CycleGAN_network.py 文件。生成网络提供两种可选的网络结构:Unet 网络结构和普通的编码器-解码器网络结构。生成网络损失函数由 LSGAN 的损失函数,重构损失和自身损失组成,判别网络的损失函数由 LSGAN 的损失函数组成。


CycleGAN 最经典的地方是设计和提出了循环一致性损失。以黑白图片上色为例,循环一致性就是:黑白图(真实)—>网络—>彩色图—>网络—>黑白图(造假)。为了保证上色后的彩色图片中具有原始黑白图片的所有内容信息,文章中将生成的彩色图像还原回去,生成造假的黑白图,通过损失函数来约束真实白图和造假黑白图一致,达到图像上色的目的。除此之外,CycleGAN 不像 Pix2Pix 一样,需要使用配对数据进行训练,CycleGAN 直接使用两个域图像进行训练,而不用建立每个样本和对方域之间的配对关系,这就厉害了,一下子让风格迁移任务变得简单很多。


看一下 CycleGAN 的网络结构图:



如果想了解更多详情,可以阅读 CycleGAN 的原论文,推荐读一读,会有更深刻和更清楚的理解,下面给出链接:

https://arxiv.org/abs/1703.10593

CycleGAN 的实现

代码和数据集


这里我提供了一个包含代码和数据集的仓库链接:https://git.openi.org.cn/tjulitianyi/CycleGAN_MindSpore,但是更建议使用最新版本代码,见下方特别说明。


特别说明:我们将在华为云 ModelArts 的 NoteBook,基于 MindSpore-GPU 1.8.1 运行 CycleGAN 的代码,因为云环境的更新不确定性,所以运行可能会报错,这时可以参考如下最新代码:https://gitee.com/mindspore/models/tree/master/research/cv/CycleGAN。


需要提醒大家的是,必须需要使用 MindSpore 1.8.0 以及以上的版本,之前版本会报错,因为某些 API 不支持。而最新的 1.8.1 版本有时也会报错,报错信息如下,怀疑可能是代码的设置有些问题:



目前 ModelArts 最高支持到 MindSpore 1.7,我们需要自行安装最新的 MindSpore 1.8.1 版本。

先来看看我使用的 NoteBook 环境:



这里特别提醒大家,NoteBook 是要花钱的,我选择的单卡 Tesla V100 大约每小时 28 元,也有更便宜的,大概每小时 8 元的单卡 Tesla P100,请大家根据自身情况选择,千万注意使用情况,别欠费了。

准备环境


下面进入 NoteBook,打开一个终端:



先来看看我们的信息和显卡 CUDA Version:



我们看到 CUDA Version 是 10.2,下面到 MindSpore 官网看看安装教程,我们需要安装 MindSpore 1.8.1,但是没有 CUDA 10.2 对应的版本,这里就选择就近的 CUDA 10.1 版本了。



在终端执行如下命令:


pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/1.8.1/MindSpore/gpu/x86_64/cuda-10.1/mindspore_gpu-1.8.1-cp37-cp37m-linux_x86_64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple
复制代码


下载速度很快,安装速度也是非常快:



最后运行显示如下信息,表示安装成功了:


获取代码


接下来下载代码,执行如下命令(由于要下载整个仓库,时间有点长):


git clone https://gitee.com/mindspore/models.git
复制代码


命令运行截图:



下面我们将感兴趣的 CycleGAN 代码拷贝到当前目录下,执行如下命令:


cp -r models/research/cv/CycleGAN/ ./
复制代码

准备数据集


下面进入 CycleGAN 目录:


cd CycleGAN
复制代码



我们这里使用的是 monet2photo 数据集,由于直接在 ModelArts 的 NoteBook 下载速度很慢,所以建议大家下载到本地,再上传到 NoteBook 的 CycleGAN/data 目录下,下载链接为:https://s3.openi.org.cn/opendata/attachment/7/b/7beb4534-6e79-463e-a7c6-032510bab215?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=1fa9e58b6899afd26dd3%2F20220814%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20220814T085624Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B filename%3D”monet2photo.zip“&X-Amz-Signature=20fbfd9c798701efcbf21d811f3dfdd6b8d5744f388c799bc38715f7fe78c783


上传完成后,解压数据集即可。我的运行截图如下图所示:



启动训练


注意,请在 CycleGAN 的目录下启动训练,如下图所示:



我是在 GPU 下的单卡训练,所以启动训练的命令为:


python train.py --platform GPU --device_id 0 --model ResNet --max_epoch 200 --dataroot ./data/monet2photo/ --outputs_dir ./outputs
复制代码


运行截图为:



可以看到已经成功启动训练,打印出 loss,此时我是用的 Tesla V100 显卡大约占了 4GB 显存,利用率接近 100%,此时来看不适合用 Tesla V100 来跑,未能发挥其大显存的优势,而其计算能力其实一般。CycleGAN 模型训练比较费时间,请注意花费,预计完成全部 200epoch 的训练需要 72 小时以上。

评估模型


python eval.py --platform GPU --device_id 0 --model ResNet --G_A_ckpt ./outputs/ckpt/G_A_200.ckpt --G_B_ckpt ./outputs/ckpt/G_B_200.ckpt
复制代码


注意,这里的.ckpt 模型名称,请根据实际训练生成的具体轮数的模型名称太难写,比如目前只保存了 20epoch 的模型,那上述命令的 200 就应该改成 20。


更多命令或适配其他硬件平台和多卡情况,可参考 scripts 文件夹下脚本。

结语


我们简单介绍了著名的 CycleGAN,给出了基于 MindSpor 的完整代码,并带着大家跑了一遍,目前有些问题,后续会更新。作为经典的 GAN 的一种,CycleGAN 有很多值得我们学习的地方,还需要深入分析挖掘,以鉴今事。


关于代码运行的问题,可以到官仓提交 issue 求助,下为链接:https://gitee.com/mindspore/models/issues


点击关注,第一时间了解华为云新鲜技术~

发布于: 刚刚阅读数: 3
用户头像

提供全面深入的云计算技术干货 2020.07.14 加入

华为云开发者社区,提供全面深入的云计算前景分析、丰富的技术干货、程序样例,分享华为云前沿资讯动态,方便开发者快速成长与发展,欢迎提问、互动,多方位了解云计算! 传送门:https://bbs.huaweicloud.com/

评论

发布
暂无评论
带你徒手完成基于MindSpore的CycleGAN实现_人工智能_华为云开发者联盟_InfoQ写作社区