写点什么

AI 简报 -GAN 和 CGAN

  • 2022 年 7 月 13 日
  • 本文字数:2693 字

    阅读完需:约 9 分钟

AI简报-GAN和CGAN

1.GAN 基础

Generative Adversarial Nets 是 GAN 的开山之作,通过对抗的方式,去学习数据分布的生成式模型。GAN 是无监督的过程,能够捕捉数据集的分布,以便于可以从随机噪声中生成同样分布的数据。

GAN 涉及两个模型,一个是 D 判别式模型和 G 生成式模型

  • D 判别式模型:学习真假边界,判断数据是真的还是假的

  • G 生成式模型:学习数据分布并生成数据

所谓对抗是指生成器 G 不断的生成数据让判别器 D 进行判断,一开始生成器 G 生成的数据很容易就被判别器揪出来,但是随着生成器 G 学会了判别器 D 的判别一些标准或者特征,也学习了这些特征生成新的数据,再次让判别器判断。如此迭代,直到生成器 G 生成的数据,判别器无法判断真假,以假乱真的情况下,生成器 G 就出师了。



GAN 经典的 loss 如下(minmax 体现的就是对抗)



训练过程:

  • 首先先训练多轮的判别器 D:生成器 G 不变,使得判别器 D,能够正确区分真假(真的数据是要学习分布的数据,假的数据是随机分布通常是正态分布的数据),从 loss 看就是最大化真的数据和最大化 1-假的数据

  • 在训练生成器 G:判别器不变,原文是最大化假的 loss



结果展示:



2.CGAN 介绍

GAN 的目的是学习输入数据的分布特征,然后再根据这个特征根据随机噪声输入生成相同分布的数据。理想情况下,GAN 可以生成与训练数据相似的数据,但是每次生成的图像是不受控制的。CGAN 是指在 GAN 上加一个条件,让 GAN 的生成受这个输入条件的影响。比如在论文中Conditional Generative Adversarial Nets加入的条件是希望生成的类别的标签; 当然这个这个条件可以是文本,图片等。




  • 条件文本:Text-to-Image 根据文本来生成对应的图片

  • 条件是图片:Image-to-Image (pix2pix: Image-to-Image Translation with Conditional Adversarial Networks):如下: 



可见,CGAN 是属于监督学习的过程,需要 paired data。


2.1 CGAN 实现说明

CGAN 要求 paired 数据, 判别器和生成器都比 GAN 多一个 condition。实现的时候只要把这个 condition 以 concat 的方式加到原来噪声输入即可。见下面李宏毅先生的课件:


  • 生成器

class Generator(nn.Module):    def __init__(self):        super(Generator, self).__init__()
self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)
def block(in_feat, out_feat, normalize=True): layers = [nn.Linear(in_feat, out_feat)] if normalize: layers.append(nn.BatchNorm1d(out_feat, 0.8)) layers.append(nn.LeakyReLU(0.2, inplace=True)) return layers
self.model = nn.Sequential( *block(opt.latent_dim + opt.n_classes, 128, normalize=False), *block(128, 256), *block(256, 512), *block(512, 1024), nn.Linear(1024, int(np.prod(img_shape))), nn.Tanh() )
def forward(self, noise, labels): # Concatenate label embedding and image to produce input gen_input = torch.cat((self.label_emb(labels), noise), -1) img = self.model(gen_input) img = img.view(img.size(0), *img_shape) return img
复制代码
  • 判别器

class Discriminator(nn.Module):    def __init__(self):        super(Discriminator, self).__init__()
self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)
self.model = nn.Sequential( nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512), nn.LeakyReLU(0.2, inplace=True), nn.Linear(512, 512), nn.Dropout(0.4), nn.LeakyReLU(0.2, inplace=True), nn.Linear(512, 512), nn.Dropout(0.4), nn.LeakyReLU(0.2, inplace=True), nn.Linear(512, 1), )
def forward(self, img, labels): # Concatenate label embedding and image to produce input # 这里的concat如果是图片,也是采取concat的方式 d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1) validity = self.model(d_in) return validity
复制代码
  • train

for epoch in range(opt.n_epochs):    for i, (imgs, labels) in enumerate(dataloader):
batch_size = imgs.shape[0]
# Adversarial ground truths valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False) fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)
# Configure input real_imgs = Variable(imgs.type(FloatTensor)) labels = Variable(labels.type(LongTensor))
# ----------------- # Train Generator # -----------------
optimizer_G.zero_grad()
## 生成的图片与label相同的情况,判别的loss才是最小的 # Sample noise and labels as generator input z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim)))) gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))
# Generate a batch of images gen_imgs = generator(z, gen_labels)
# Loss measures generator's ability to fool the discriminator validity = discriminator(gen_imgs, gen_labels) g_loss = adversarial_loss(validity, valid)
g_loss.backward() optimizer_G.step()
# --------------------- # Train Discriminator # --------------------- # 两部分loss, 真实配对的loss, 和构造的配对的loss optimizer_D.zero_grad()
# Loss for real images validity_real = discriminator(real_imgs, labels) d_real_loss = adversarial_loss(validity_real, valid)
# Loss for fake images validity_fake = discriminator(gen_imgs.detach(), gen_labels) d_fake_loss = adversarial_loss(validity_fake, fake)
# Total discriminator loss d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.backward() optimizer_D.step()
复制代码

2.2 CGAN 的应用

  









发布于: 刚刚阅读数: 4
用户头像

公众号:人工智能微客(weker) 2019.11.21 加入

人工智能微客(weker)长期跟踪和分享人工智能前沿技术、应用、领域知识,不定期的发布相关产品和应用,欢迎关注和转发

评论

发布
暂无评论
AI简报-GAN和CGAN_深度学习_AIWeker-人工智能微客_InfoQ写作社区