BT

如何利用碎片时间提升技术认知与能力? 点击获取答案

如何应用TFGAN快速实践生成对抗网络?

| 作者 武维 关注 6 他的粉丝 发布于 2018年6月4日. 估计阅读时间: 18 分钟 | QCon上海2018 关注大数据平台技术选型、搭建、系统迁移和优化的经验。

前言

生成对抗网络(Generative Adversarial Nets ,GAN)目前已广泛应用于图像生成、超分辨率图片生成、图像压缩、图像风格转换、数据增强、文本生成等场景。越来越多的研发人员从事GAN网络的研究,提出了各种GAN模型的变种,包括CGAN、InfoGAN、WGAN、CycleGAN等。为了更容易地应用及实践 GAN模型,谷歌开源了名为TFGAN的TensorFlow库,可快速实践各种GAN模型。本文主要讲解TFGAN如何应用于原生GAN、CGAN、InfoGAN、WGAN等场景,如下所示:

其中,原生GAN生成的Mnist图像不可控:CGAN可按照数字标签生成相应标签的数字图像;InfoGAN可认为是无监督的CGAN,前两行表示用分类潜变量控制数字的生成类别,中间两行表示用连续型潜变量控制数字的粗细,最后两行表示用连续型潜变量控制数字的倾斜方向;ImageToImage是CGAN的一种,实现图像的风格转换。

生成对抗网络与TFGAN

GAN由Goodfellow 首先提出,主要由两部分构成:Generator(生成器),简称G;Discriminator(判别器), 简称D。生成器主要用噪声 z 生成一个类似真实数据的样本,样本越逼真越好;判别器用于估计一个样本来自于真实数据还是生成数据,判定越准确越好。如下图所示:

上图中,对于真实的采样数据,通过判别网络后,生成D(x)。D(x)的输出是0-1范围内的一个实数,用来判断这个图片是一个真实图片的概率是多大。这样对于真实数据,D(x)越接近1越好。对于随机噪声z,通过生成网络G后,G将这个随机噪声转化为生成数据x。如果是图片生成问题,G网络的输出就是一张生成的假图片,用G(z)表示。判别模型D要使得D(G(z))接近与0,即能够判断生成的图片是假的;生成模型G要使得D(G(z))接近于1,即要能够要欺骗判别模型,使得D认为G(z)生成的假数据是真的。这样通过判别模型D和生成模型G的博弈,使得D无法判断一张图片是生成出来的还是真实的而结束。

假设P_r和P_g分别代表真实数据的分布与生成数据的分布,这样判别模型的目标函数可以表示为:

而生成模型的是让判别模型D无法区别真实数据与生成数据,这样优化目标函数为:

TFGAN库的地址为https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan,主要包含以下几个组件:

  1. 核心架构,主要包括创建TFGAN模型,添加Loss值,创建训练operation,运行训练operation。
  2. 常用操作,主要提供了梯度修剪操作,归一化操作及条件化操作等。
  3. 损失函数,主要提供了GAN中常用的损失和惩罚函数,如 Wasserstein 损失、梯度惩罚、互信息惩罚等。
  4. 模型评估,提供了Inception Score和Frechet Distance指标,用于评估无条件生成模型。
  5. 示例,谷歌同时开源了常用的GAN网络示例代码,包括unconditional GAN,conditional GAN, InfoGAN,WGAN等。相关用例可从https://github.com/tensorflow/models/tree/master/research/gan/地址下载。

使用TFGAN库训练GAN网络主要包含如下几个步骤:

1. 确定GAN网络的输入,如下所示:

images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

2. 设定GANModel中的生成模型和判别模型,如下所示:

gan_model = tfgan.gan_model(
    generator_fn=mnist.unconditional_generator,  # you define
    discriminator_fn=mnist.unconditional_discriminator,  # you define
    real_data=images,
    generator_inputs=noise)

3. 设定GANLoss中的损失方程,如下所示:

gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss)

4. 设定GANTrainOps中的训练操作,如下所示:

train_ops = tfgan.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5))

5. 运行模型训练,如下所示:

tfgan.gan_train(
    train_ops,
    hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],
    logdir=FLAGS.train_log_dir)

CGAN

CGAN(Conditional Generative Adversarial Nets),针对GAN本身不可控的缺点,加入监督信息,训练从无监督变成有监督,指导GAN网络进行生成。例如输入分类的标签,可生成相应标签的图像。这样CGAN的目标方程可以转换为:

其中,y是加入的监督信息,D(x|y)表示在y的条件下判定真实数据x,D(G(z|y))表示在y的条件下判定生成数据G(z|y)。例如,MNIST数据集可根据数字label信息,生成相应标签的图片;人脸生成数据集,可根据性别、是否微笑、年龄等信息,生成相应的人脸图片。CGAN的架构如下图所示:

在TFGAN中提供了,基于one_hot_labels变量和输入tensor生成condition tensor的API,如下所示:

tfgan.features.condition_tensor_from_onehot (tensor, one_hot_labels, embedding_size)

其中,tensor为输入数据,one_hot_labels为onehot标签,shape为[batch_size, num_classes],embedding_size为每个label对应的embedding大小,返回值为condition tensor。

ImageToImage

Phillip Isola等提出了基于CGAN的图片生成图片的对抗神经网络《Image-to-Image Translation with Conditional Adversarial Networks》。网络设计的基本思想如下所示:

其中,x为输入的线条图,G(x)为生成图片,y为线条图x对应渲染后的真图片,生成模型G用于生成图片,判断模型D用于判定生成图片的真假。判别网络能够最大化判断(x,y)的数据为真,判断(x,G(x))数据为假。而生成网络使得判别网络判断(x,G(x))数据为真,从而进行生成模型和判别模型的相互博弈。为了使生成模型不仅能够欺骗判别模型,还要使得生成图像要像真实图片,这样在目标函数中加入了真实图像和生成图像的L1距离,如下所示:

TFGAN库,提供了ImageToImage生成对抗网络的相关损失方程API使用示例,如下所示:

# 定义真实数据与生成数据的L1损失
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1) / FLAGS.patch_size ** 2

# gan_loss为目标函数损失
gan_loss = tfgan.losses.combine_adversarial_loss(gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)

InfoGAN

在GAN中,生成器用噪声z生成数据时,没有加任何的条件限制,很难用z的任何一个维度信息表示相关的语义特征。所以在数据生成过程中,无法控制什么样的噪声z可以生成什么样的数据,在很大程度上限制了GAN的使用。InfoGAN可以认为是无监督的CGAN,在噪声z上增加潜变量c,使得生成模型生成的数据与浅变量c具有较高的互信息,其中Info就是代表互信息的含义。互信息定义为两个熵的差值,H(x)是先验分布的熵,H(x|y)代表后验分布的熵。如果x,y是相互独立的变量,那么互信息的值为0,表示x,y没有关系;如果x,y有相关性,那么互信息大于0。这样在已知y的情况下,可以推断出那些x的值出现高。这样InfoGAN的目标方程为:

InfoGAN的网络结构如下所示:

上图中InfoGAN与GAN的区别在于,对应判别网络的输出D(x),生成变分分布Q(c|x),从而能用Q(c|x)来逼近P(c|x),从而增大生成数据与潜变量c的互信息。

TFGAN中提供了InfoGan相关API,如下所示:

#通过tfgan.infogan_model,定义infogan模型
infogan_model = tfgan.infogan_model(
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    real_data=real_images,
    unstructured_generator_inputs=unstructured_inputs,
    structured_generator_inputs=structured_inputs)

#通过tfgan.gan_loss,生成infogan模型的loss值:
infogan_loss = tfgan.gan_loss(
    infogan_model,
    gradient_penalty_weight=1.0,
    mutual_information_penalty_weight=1.0)

#InfoGan的Loss值为在GAN的loss值上,加上互信息I(c;G(z,c)),TFGAN中提供了互信息计算的API,如下所示。其中structured_generator_inputs为潜变量的噪音信息,predicted_distributions为变分分布Q(c|x)。

def mutual_information_penalty(structured_generator_inputs, predicted_distributions)

WGAN

Martin Arjovsky等提出了WGAN(Wasserstein GAN),解决了传统GAN训练困难、生成器和判别器的loss很难指示训练进程、生成样本缺乏多样性等问题,主要有以下优点:

  1. 能够平衡生成器和判别器的训练程度,使得GAN的模型训练稳定。
  2. 能够保证生产样本的多样性。
  3. 提出使用Wasserstein距离来衡量模型训练的程度,数值越小表示训练得越好,成器生成的图像质量越高。

WGAN的算法与原始GAN算法的差异主要体现在:

  1. 去掉判别模型最后一层的sigmoid操作。
  2. 生成模型和判别模型的loss值不取log操作。
  3. 每次更新判别模型的参数之后把模型参数的绝对值截断到不超过固定常数c。
  4. 使用RMSProp算法,不用基于动量的优化算法,例如momentum和Adam。

WGAN的算法结构如下所示:

TFGAN中提供了WGan相关API,如下所示:

#生成网络损失方程
generator_loss_fn=tfgan_losses.wasserstein_generator_loss
#判别网络损失方程
discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss

总结

本文首先介绍了生成对抗网络和TFGAN,生成对抗网络模型用于图像生成、超分辨率图片生成、图像压缩、图像风格转换、数据增强、文本生成等场景;TFGAN是TensorFlow库,用于快速实践各种GAN模型。然后讲解了CGAN、ImageToImage、InfoGAN、WGAN模型的主要思想,并对关键技术进行了分析,主要包括目标函数、网络架构、损失方程及相应的TFGAN API。用户可基于TFGAN快速实践生成对抗网络模型,并应用到工业领域中的相关场景。

参考文献

[1] Generative Adversarial Networks.
[2] Conditional Generative Adversarial Nets.
[3] InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets.
[4] Wasserstein GAN.
[5] Image-to-Image Translation with Conditional Adversarial Networks.
[6] https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan.
[7] https://github.com/tensorflow/models/tree/master/research/gan.

作者简介

武维(微信:allawnweiwu):博士,现为IBM架构师。主要从事深度学习平台及应用研究,大数据领域的研发工作。

评价本文

专业度
风格

您好,朋友!

您需要 注册一个InfoQ账号 或者 才能进行评论。在您完成注册后还需要进行一些设置。

获得来自InfoQ的更多体验。

告诉我们您的想法

允许的HTML标签: a,b,br,blockquote,i,li,pre,u,ul,p

当有人回复此评论时请E-mail通知我
社区评论

允许的HTML标签: a,b,br,blockquote,i,li,pre,u,ul,p

当有人回复此评论时请E-mail通知我

允许的HTML标签: a,b,br,blockquote,i,li,pre,u,ul,p

当有人回复此评论时请E-mail通知我

讨论

登陆InfoQ,与你最关心的话题互动。


找回密码....

Follow

关注你最喜爱的话题和作者

快速浏览网站内你所感兴趣话题的精选内容。

Like

内容自由定制

选择想要阅读的主题和喜爱的作者定制自己的新闻源。

Notifications

获取更新

设置通知机制以获取内容更新对您而言是否重要

BT