爱吧机器人网 » 技术 > 机器学习 > 正文

深度学习之生成式对抗网络(GAN)入门指南

近年来,神经网络已经取得了很大进展,从能感知图像、声音,到转录人类自然语言,它的发展为我们开启了一扇崭新的大门。但即便如此,我们现在已实现的“智能”距离真正的智能还有不小的差距,机器人能通过传感器收集光谱、声波信息,但它们并不能做到“理解”。这也是许多人强调社会还需要人类的“创造力”的原因:机器人无法自己组织简单语言解释概念,也不能像艺术家一样进行创作。
 
当然,这个想法放到现在可能有些过时,因为自2014年Ian J. Goodfellow等人首先提出生成对抗网络(Generative Adversarial Network,GAN)后,在短短三年间,这个还不太成熟的深度学习模型就已经成了无监督学习最具前景的方法之一,让许多原本以为需要“创造力”的行为实现了自动化。
 
从零学习系列第二篇:生成敌对网络(GAN)入门指南,来自数据科学爱好者、深度学习开发者Faizan Shaikh。
 
本文将介绍GAN的基础概念及其工作方式,并辅之以有趣案例的实现方法和重要资源,方便初学者训练、使用。
 
目录
 
什么是GAN?
GAN的工作方式
如何训练GAN
GAN的痛点
GAN的应用
 
什么是GAN?
 
深度学习领域的大牛Yann LeCun曾在Quora会议上表示:
 
在我看来,(GAN)是近10年来ML领域提出的最有趣的想法。
 
这样的评价令人振奋,但似乎对于理解没有任何用处,作为普通的数据科学家,我们眼中的GAN也许更多的是一些实际意义。
 
那么,什么是GAN?对于这个问题,我们先来打个比方:如果你想改善某些事情,比如说提高下棋水平,你会怎么做?相信普通人的回答都是找一个比自己更强的对手并与之竞争,分析战术技巧、积累经验,直至击败他。GAN的思路也一样,为了成为一个下棋高手(生成模型generator),我们需要一个更强大的对手(判别模型discriminator)。

\
 
生成器和判别器的关系可以说是伪造者和调查者的关系。以伪造名画为例,生成器的任务是仿照原画生成赝品,如果蒙混过关(输出),他会得到丰厚奖励。而判别器的任务则是找出赝品和原画的差异,他会从原画中提取特征作为比较内容,以此评估生成的图像是否真实。
 
如果说这还不够形象,让我们借用微软亚洲研究院的描述:合格男友养成计划。
 
男:哎,你看我给你拍的好不好?
 
女:这是什么鬼,你不能学学XXX的构图吗?
 
男:哦
 
……
 
男:这次你看我拍的行不行?
 
女:你看看你的后期,再看看YYY的后期吧,呵呵
 
男:哦
 
……
 
男:这次好点了吧?
 
女:呵呵,我看你这辈子是学不会摄影了
 
……
 
男:这次呢?
 
女:嗯,我拿去当头像了
 
在这个情景中,我们的目标是把男友培养成一个合格的~~陈老师~~摄(拍)影(照)师(的)。产出照片的男友是生成器,鉴别照片质量、审美要求更高的女友是判别器。可以发现,在训练时,每当男友上交一张照片,女友就会指出它们和目标特征(构图、后期)的差距,之后男友根据反馈进行学习,经过数轮重复后,最后他拍出了令人满意的照片。
 
当然,如果女友水平过高,或者太过天马行空,而男友只是个木讷的“老实人”,那么恭喜你,他们的这段关系(GAN)已经崩溃了。
 
GAN的工作方式
 
现在,我们已经大致理解了GAN的概念,可以进一步了解它的工作本质了。
 
如下图所示,GAN主要由生成器神经网络(Generator Network)和判别器神经网络(Discriminator Network)构成的:

\
生成器神经网络的任务接收随机输入并尝试生成一个数据样本输入判别器,而判别器神经网络的任务是同时从真实数据和生成器处接收输入,并预测输入是真实的还是生成的。在上图中,我们可以看到生成器G(z)从随机输入p(z)中取了一个样本z,由此产生一个数据输入判别器神经网络D(x),与此同时,D(x)也从真实数据pdata(x)中获得了输入。这之后,D(x)对两个输入用激活函数(sigmoid)进行二元分类,输出范围在0—1之间的概率。
 
让我们再理一理图中的符号:
 
Pdata(x):真实数据的分布;
 
X:pdata(x)中的样本;
 
P(z):生成器数据分布;
 
Z:p(z)中的样本;
 
G(z):生成器神经网络;
 
D(x):判别器神经网络。
 
这就是一个基础的GAN,而训练它的方式就是让生成器和判定器互相对抗。这一过程可以用数学来表示:

\
 
如上式所示,判定器的目标的是使V最大化,而生成器的目标是使V最小化(真实数据与生成数据之间的差异最小化)。换句话说,这是发生在生成器和判定器之间的猫鼠游戏。
 
正如我们之前提到的,GAN要训练D、G两个神经网络,我们先固定G看D。由于V(D, G)表示的是差异大小,因此对于判别器D,它希望V越接大越好。其中第一项——将Pdata(x)数据映射到判定器内的熵——因为它是真的,所以想被分成1;而对于第二项,它是是P(z)数据映射到生成器内,由此生成假样本输入判定器内的熵,如果D(G(z))被错分为1,那V就无穷小了,所以我们要它接近0。
 
之后,我们固定D看G。由于第一项不含G成了常数,所以我们可以直接看第二项。可以发现,既然我们的目标是使V最小化,那就是让第二项最小化,那么D(G(z))就该无限靠近1。
 
注:这种训练GAN的方法受极大极小博弈(minimax game)启示。
 
如何训练GAN
 
广泛地说,GAN的训练主要由两部分组成,而且它们还是按顺序进行的。
 
Pass 1:固定生成器训练判定器(固定意味着将生成器的结果设置为假,神经网络只做正常传播,不做反向传播);

\
 
Pass 2:固定判定器训练生成器。
 

\
第一步:定义问题。
 
确定你想生成的对象,是假图像还是加文本,定义问题并搜集数据。
 
第二步:定义GAN的体系结构。
 
为你的GAN选定一种结构,比如你的生成器和判定器是多层感知器还是卷积神经网络。这主要取决于你想解决什么问题。
 
第三步:在真实数据上训练判别器,epoch=n。
 
训练判别器在真实数据上做出正确预测,轮次n可以是大于等于1的任意自然数。
 
第四步:为生成器生成假数据,并在假数据上训练判别器。
 
训练判别器正确鉴别假数据为假。
 
第五步:用判别器的输出训练生成器。
 
将判别器的预测结果作为生成器的目标,训练生成器去“欺骗”判别器。
 
第六步:重复步骤3—5。
 
第七步:手动检查生成的假数据是否符合期望:如果符合,停止训练;如果有瑕疵,重回第三步。
 
检查数据是否伪造的最好方法是手动检查,这时你可以评估自己的GAN是否运行良好。
 
现在,你只需深呼吸一口并静待结果,想象一下,如果有一个功能齐全的生成器,那你就几乎能“伪造”任何东西了。事实上,现在比较常见的应用是生成假新闻、创作情节令人匪夷所思的小说、设置自动答录等。
 
GAN的痛点
 
看到这里,你可能会问,既然我们已经有了这样强大的框架,那为什么没有实现什么重大突破呢?事实上,这是因为我们对GAN的理解还停留在表面,即使是“GANs之父”Ian J. Goodfellow,他也无法清除构建一个“足够好”的GAN的过程中的层层阻碍。在他去年发表的论文Improved Techniques for Training GANs中,他还在探讨该如何训练一个GAN。
 
现在GAN所面临的最重要的问题是稳定性。如果你训练了一个GAN,生成器很弱小,但是判别器却异常强大,你就会发现训练后模型性能很差,因为生成器无法根据反馈有效训练,而这也反过来影响了整个网络。这一点是由损失函数缺失造成的。之前我们提到过,GAN的训练方法启发自极大极小博弈,不用计算损失,就意味着神经网络并不知道自己是否在进步。
 
另一方面,如果判别器不够强,鉴别范围过于宽泛,那生成器就可以自由生成任何图像,这样导致的训练结果也是一个无用的GAN。让我们回到算式那一节,P(z)是符合分布的生成数据,因为没有预先建模,所以这种随机采样的方式在理论上更接近真实数据,但是这样做的弊端是当面对较大数据时,神经网络缺少约束,会变得过于自由,而且不可控制。
 
此外,GAN的稳定性问题还体现在它的整体收敛问题上。一方面,生成器和判别器在互相对抗;另一方面,其实它们也互相依赖着进行有效训练。如果一方出现问题,那整个系统就会失败,所以你必须保证它们不崩溃。
 
这有点像电子游戏波斯王子(Prince of Persia)的情景,王子必须防卫影子的攻击,以免被杀死。如果他杀死了影子,他也会死;如果他什么都不做,那他肯定会死。
 
下面还有一些GAN面临的应用问题:
 
注:以下图像是在ImageNet数据集上训练的GAN生成的。
 
计数问题。GAN无法区分某个位置具体该生成多少特定对象。如下图所示,这些“动物”头部的眼睛太多了;

\
 
透视问题。GAN无法适应3D对象,它分辨不了前景和背景的透视差异。如下图所示,它把3D对象转成了3D表示;

\
 
全局构造问题。和透视问题一样,GAN也完全把握不了全局构造。例如在下图中,它生成了一头奇怪的牛,它靠两条后腿站着,但是又四脚着地。

\
 
针对这些问题,现在我们也有了DCGAN、WassersteinGAN等训练更精确模型的方法。
 
实现一个玩具GAN
 
看完理论,让我们实现一个GAN来加深学习印象。
 
任务:训练一个能自动生成数字的GAN;
 
数据集:28×28个黑白数字图像,格式为png;
 
前往analyticsvidhya下载数据集(详情请询小编),注意:这也是一个比赛活动,有兴趣的读者可以前往参加,截止时间还有一周左右。
 
设置环境:
 
numpy
 
pandas
 
tensorflow
 
keras
 
keras_adversarial
 
在开始写打码前,我们先用伪代码了解下内部实现机制:

\
不是唯一实现,还有多种更新/改进版
 
(鉴于以下内容过长,且全是代码,没有太大的观赏性,因此不放。如有需要,可找头条号“论智”的小编)
 
GAN的应用
 
之前我们介绍了GAN的概念、数学计算、搭建方法等内容,现在可以围观一下当前学界围绕GAN的尖端研究。
 
预测视频的下一帧。你可以在视频序列上训练GAN,并让它预测下一个画面会是什么;
\
 
增加图像分辨率。你可以用GAN生成高清无码图片;

\
 
交互式图像生成。GAN可以实现寥寥几笔就画出令人印象深刻的图片。
 
图像翻译:用一个图像生成另一张图像。如下图所示,左侧图像是传感器扫描到的标签图像、手提包线条画,右侧是经GAN预测的真实街景图像和真实包包;
 
\
由文本生成图像。你可以打字告诉GAN你想要什么,它会为你生成相应对象的图片。
 
\


上一篇:【Science】CMU机器学习系主任:八个关键标准判别深度学习任务成功与否
下一篇:6步创建一个通用机器学习模板
精选推荐

[2017-03-21]  虽然有很多关于机器人取代工人的担心,但哈佛经济学家James Bessen的论文指出,在过去的67年里机器人仅仅淘汰掉人类工作中的一个。在1950 ...

英伟达用联合学习创建医学影像AI 可共享数据和保护隐私
英伟达用联合学习创建医学影像AI 可共享数据和保护隐私

[2019-10-14]  英伟达(Nvidia)和伦敦国王学院(King’s College London)的人工智能研究人员利用联合学习训练了一种用于脑肿瘤分类的神经网络, ...

通过对抗性图像黑入大脑
通过对抗性图像黑入大脑

[2018-03-02]  在上面的图片中,左边是一张猫的照片。在右边,你能分辨出它是同一只猫的图片,还是一张看起来相似的狗的图片?这两张图片之间的区别在于, ...

助力卷积神经网络时空特征学习 史上最大行人重识别视频数据集被提出
助力卷积神经网络时空特征学习 史上最大行人重识别视频数据集被提出

[2017-12-25]  本文提出了一个大型的、长序列的、用于行人重识别的视频数据集,简称LVreID。与现有的同类数据集相比,该数据集具有以下特点:1)长序列:平均每段视频序列长为200帧,包含丰......

麻省理工正研究植物机器人 让植物自主控制机器人
麻省理工正研究植物机器人 让植物自主控制机器人

[2018-12-08]  控制论通常指人类用机器人部件增强自己。我们听说过动物机器人或昆虫机器人,但我们很少听说植物机器人对吧?一个机器人其实是对植物有很大益处的,因为一般植物根本无法移动......

苹果AI主管透露自动驾驶汽车项目关于机器学习方面的进展
苹果AI主管透露自动驾驶汽车项目关于机器学习方面的进展

[2017-12-11]  苹果隐秘的自动驾驶汽车项目多年来一直在转移焦点,但今年似乎正在加速。 4月份,公司获得了在加利福尼亚州进行自动驾驶汽车测试的许可证,而在6月份,苹果公司首席执行官库......

谷歌《Nature》发论文称实现量子霸权 18亿倍速碾压世界最强超算
谷歌《Nature》发论文称实现量子霸权 18亿倍速碾压世界最强超算

[2019-10-23]  谷歌坚称自己已经取得了量子霸权——这标志着计算研究领域的一个重要里程碑。谷歌首次发布声明是在今年9月,虽然遭到竞争对手的质疑,但就 ...

全自动膝关节置换手术机器人被美国FDA批准上市
全自动膝关节置换手术机器人被美国FDA批准上市

[2019-10-14]  美国Think Surgical公司已获得美国食品和药物管理局(FDA)的批准,在美国销售用于全膝关节置换(TKA)的TSolution One®全膝关节应用 ...

本周栏目热点

盘点全球十大最具影响力的机器人摇篮

[1970-01-01]    人工智能(AI)研究现正迅速发展,如无人驾驶汽车、计算机在《危险边缘》智力竞赛节目中获胜、数字私人助手Siri、GoogleNow和语音助手C ...

深度学习反向传播算法(BP)原理推导及代码实现

[2017-12-19]  分析了手写字数据集分类的原理,利用神经网络模型,编写了SGD算法的代码,分多个epochs,每个 epoch 又对 mini_batch 样本做多次迭代计算。这其中,非常重要的一个步骤,......

如何在机器学习项目中使用统计方法的示例

[2018-07-23]  事实上,机器学习预测建模项目必须通过统计学方法才能有效的进行。在本文中,我们将通过实例介绍一些在预测建模问题中起关键作用的统计学方法。...

[2017-08-28]  模拟退火(Simulated Annealing,简称SA)是一种通用概率算法,用来在一个大的搜寻空间内找寻命题的最优解。1、固体退火原理:将固体加温 ...

Machine Learning-感知器分类算法详解

[2018-05-31]  今天我们来讲解的内容是感知器分类算法,本文的结构如下:什么是感知器分类算法,在Python中实现感知器学习算法,在iris(鸢尾花)数据集上训练一个感知器模型,自适应线性神......