爱吧机器人网 » 技术 > 神经网络 > 正文

50行代码玩转生成对抗网络GAN模型!(附源码)

本文为大家介绍了生成对抗网络(Generate Adversarial Network,GAN),以最直白的语言来讲解它,最后实现一个简单的 GAN 程序来帮助大家加深理解。

什么是 GAN?

好了,GAN 如此强大,那它到底是一个什么样的模型结构呢?我们之前学习过的机器学习或者神经网络模型主要能做两件事:预测分类,这也是我们所熟知的。那么是否可以让机器模型自动来生成一张图片、一段语音?而且可以通过调整不同模型输入向量来获得特定的图片和声音。例如,可以调整输入参数,获得一张红头发、蓝眼睛的人脸,可以调整输入参数,得到女性的声音片段,等等。也就是说,这样的机器模型能够根据需求,自动生成我们想要的东西。因此,GAN 应运而生!

GAN,即生成对抗网络,主要包含两个模块:

生成器(Generative Model)
判别器(Discriminative Model)

生成模型和判别模型之间互相博弈、学习产生相当好的输出。以图片为例,生成器的主要任务是学习真实图片集,从而使得自己生成的图片更接近于真实图片,以“骗过”判别器。而判别器的主要任务是找出出生成器生成的图片,区分其与真实图片的不同,进行真假判别。在整个迭代过程中,生成器不断努力让生成的图片越来越像真的,而判别器不断努力识别出图片的真假。这类似生成器与判别器之间的博弈,随着反复迭代,最终二者达到了平衡:生成器生成的图片非常接近于真实图片,而判别器已经很难识别出真假图片的不同了。其表现是对于真假图片,判别器的概率输出都接近 0.5。

对 GAN 的概念还是有点不清楚?没关系,举个生动的例子来说明。

最近,我想学习绘画,是因为看到梵大师的画作,也想画出类似的作品。梵大师的画作像这样:


说画就画,我找来一个研究梵大师作品很多年的王教授来指导我。王教授经验丰富,眼光犀利,市面上模仿梵大师的画作都难逃他的法眼。王教授跟我说了一句话:什么时候你的画这幅画能骗过我,你就算是成功了。

我很激动,立马给王教授画了这幅画:


王教授轻轻扫了一眼,满脸黑线,气的直哆嗦,“0 分!这也叫画?差得太多了!” 听了王教授的话,我开始自我反省,确实画的不咋地,连眼睛、鼻子都没有。于是,又 重新画了一幅:


王教授一看,不到 2 秒钟,就丢下四个字:1 分!重画!我一想,还是不行,画得太差了,就回去好好研究梵大师的画作风格,不断改进,重新创作,直到有一天,我拿着新的画作给王教授看:


王教授看了一看,说有点像了。我得仔细看看。最后,还是跟我说,不行不行,细节太差!继续重新画吧。唉,王教授越来越严格了!我叹了口气回去继续研究,最后将自我很满意的一幅画交给了王教授鉴赏:


这下,王教授戴着眼镜,仔细品析,许久之后,王教授拍着我的肩膀说,画得很好,我已经识别不了真假了。哈哈,得到了王教授的夸奖和肯定,心里美滋滋,终于可以创作出梵大师样的绘画作品了。下一步考虑转行去。

好了,例子说完了(接受大家对我绘画天赋的吐槽)。这个例子,其实就是一个 GAN 训练的过程。我就是生成器,目的就是要输出一幅画能够骗过王教授,让王教授真假难辨!王教授就是判别器,目的就是要识别出我的画作,判断其为假的!整个过程就是“生成 — 对抗”的博弈过程,最终,我(生成器)输出一幅“以假乱真”的画作,连王教授(判别器)都难以区分了。

这就是 GAN,懂了吧。

GAN 模型基本结构

在认识 GAN 模型之前,我们先来看一看 Yann LeCun 对未来深度学习重大突破技术点的个人看法:

The most important one, in my opinion, is adversarial training (also called GAN for Generative Adversarial Networks). This is an idea that was originally proposed by Ian Goodfellow when he was a student with Yoshua Bengio at the University of Montreal (he since moved to Google Brain and recently to OpenAI).

This, and the variations that are now being proposed is the most interesting idea in the last 10 years in ML, in my opinion.


Yann LeCun 认为 GAN 很可能会给深度学习模型带来新的重大突破,是20年来机器学习领域最酷的想法。这几年 GAN 发展势头非常强劲。下面这张图是近几年 ICASSP 会议上所有提交的论文中包含关键词 “generative”、“adversarial” 和 “reinforcement” 的论文数量统计。


数据表明,2018 年,包含关键词 “generative” 和 “adversarial” 的论文数量发生井喷式增长。不难预见, 未来几年关于 GAN 的论文会更多。

下面来介绍一下 GAN 的基本结构,我们已经知道了 GAN 由生成器和判别器组成,各用 G 和 D 表示。以生成图片应用为例,其模型结构如下所示:


GAN 基本模型由 输入 Vector、G 网络、D 网络组成。其中,G 和 D 一般都是由神经网络组成。G 的输出是一幅图片,只不过是以全连接形式。G 的输出是 D 的输入,D 的输入还包含真实样本集。这样, D 对真实样本尽量输出 score 高一些,对 G 产生的样本尽量输出 score 低一些。每次循环迭代,G 网络不断优化网络参数,使 D 无法区分真假;而 D 网络也在不断优化网络参数,提高辨识度,让真假样本的 score 有差距。

最终,经过多次训练迭代,GAN 模型建立:


最终的 GAN 模型中,G 生成的样本以假乱真,D 输出的 score 接近 0.5,即表示真假样本难以区分,训练成功。

这里,重点要讲解一下输入 vector。输入向量是用来做什么的呢?其实,输入 vector 中的每一维度都可以代表输出图片的某个特征。比如说,输入 vector 的第一个维度数值大小可以调节生成图片的头发颜色,数值大一些是红色,数值小一些是黑色;输入 vector 的第二个维度数值大小可以调节生成图片的肤色;输入 vector 的第三个维度数值大小可以调节生成图片的表情情绪,等等。


GAN 的强大之处也正是在于此,通过调节输入 vector,就可以生成具有不同特征的图片。而这些生成的图片不是真实样本集里有的,而是即合理而又没有见过的图片。是不是很有意思呢?下面这张图反映的是不同的 vector 生成不同的图片。


说完了 GAN 的模型之后,我们再来简单看一下 GAN 的算法原理。既然有两个模块:G 和 D,每个模块都有相应的网络参数。

先来看 D 模块,它的目标是让真实样本 score 越大越好,让 G 产生的样本 score 越小越好。那么可以得到 D 的损失函数为:


其中,x 是真实样本,G(z) 是 G 生成样本。我们希望 D(x) 越大越好,D(G(z)) 越小越好,也就是希望 -D(x) 越小越好,-log(1-D(G(z))) 越小越好。从损失函数的角度来说,能够得到上式。

再来看 G 模块,它的目标就是希望其生成的模型能够在 D 中得到越高的分数越好。那么可以得到 G 的损失函数为:


知道了损失函数之后,接下来就可以使用各种优化算法来训练模型了。

动手写个 GAN 模型

接下来,我将使用 PyTorch 实现一个简单的 GAN 模型。仍然以绘画创作为例,假设我们要创造如下“名画”(以正弦图形为例):


生成该“艺术画作”的代码如下:

def artist_works():    # painting from the famous artist (real target)
   r = 0.02 * np.random.randn(1, ART_COMPONENTS)
   paintings = np.sin(PAINT_POINTS * np.pi) + r
   paintings = torch.from_numpy(paintings).float()
   return paintings
然后,分别定义 G 网络和 D 网络模型:

G = nn.Sequential(                  # Generator
   nn.Linear(N_IDEAS, 128),        # random ideas (could from normal distribution)
   nn.ReLU(),
   nn.Linear(128, ART_COMPONENTS), # making a painting from these random ideas
)

D = nn.Sequential(                  # Discriminator
   nn.Linear(ART_COMPONENTS, 128), # receive art work either from the famous artist or a newbie like G
   nn.ReLU(),
   nn.Linear(128, 1),
   nn.Sigmoid(),                   # tell the probability that the art work is made by artist
)

我们设置 Adam 算法进行优化:

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

最后,构建 GAN 迭代训练过程:

plt.ion()    # something about continuous plotting

D_loss_history = []
G_loss_history = []
for step in range(10000):
   artist_paintings = artist_works()          # real painting from artist
   G_ideas = torch.randn(BATCH_SIZE, N_IDEAS) # random ideas
   G_paintings = G(G_ideas)                   # fake painting from G (random ideas)
   
   prob_artist0 = D(artist_paintings)         # D try to increase this prob
   prob_artist1 = D(G_paintings)              # D try to reduce this prob
   
   D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
   G_loss = torch.mean(torch.log(1. - prob_artist1))
   
   D_loss_history.append(D_loss)
   G_loss_history.append(G_loss)
   
   opt_D.zero_grad()
   D_loss.backward(retain_graph=True)    # reusing computational graph
   opt_D.step()
   
   opt_G.zero_grad()
   G_loss.backward()
   opt_G.step()
   
   if step % 50 == 0:  # plotting
       plt.cla()
       plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
       plt.plot(PAINT_POINTS[0], np.sin(PAINT_POINTS[0] * np.pi), c='#74BCFF', lw=3, label='standard curve')
       plt.text(-1, 0.75, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 8})
       plt.text(-1, 0.5, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 8})
       plt.ylim((-1, 1));plt.legend(loc='lower right', fontsize=10);plt.draw();plt.pause(0.01)

plt.ioff()
plt.show()
 
我采用了动态绘图的方式,便于时刻观察 GAN 模型训练情况。

迭代次数为 1 时:


迭代次数为 200 时:


迭代次数为 1000 时:


迭代次数为 10000 时:


完美!经过 10000 次迭代训练之后,生成的曲线已经与标准曲线非常接近了。D 的 score 也如预期接近 0.5。

完整代码有 .py 和 .ipynb 两种版本,点击“阅读原文”即可获得。


上一篇:斯坦福大学证明:神经网络能直接在光学芯片上训练!
下一篇:CNN经典论文研读之VGG网络及其tensorflow实现
精选推荐
谷歌《Nature》发论文称实现量子霸权 18亿倍速碾压世界最强超算
谷歌《Nature》发论文称实现量子霸权 18亿倍速碾压世界最强超算

[2019-10-23]  谷歌坚称自己已经取得了量子霸权——这标志着计算研究领域的一个重要里程碑。谷歌首次发布声明是在今年9月,虽然遭到竞争对手的质疑,但就 ...

Waymo:人性和行为心理学才是无人驾驶最大的挑战
Waymo:人性和行为心理学才是无人驾驶最大的挑战

[2019-11-03]  自动驾驶汽车作为AI领域内最大的挑战之一,谷歌致力于其研发已有十余载,现在他们逐渐意识到,最困难的是如何让人们享受驾驶的乐趣。这是一 ...

麻省理工最新机器人“装配工”未来可建造太空基地
麻省理工最新机器人“装配工”未来可建造太空基地

[2019-10-17]  两个机器人原型把一系列小单元组装成大结构体麻省理工学院科研人员最近提出一种新型机器人技术,即一种小型机器人系统,能够自主地用统一规 ...

7种常见的机器人焊接类型
7种常见的机器人焊接类型

[2017-12-17]  机器人焊接是工业领域最常见的机器人应用之一,近几十年来主要由汽车行业驱动。机器人焊接在完成大批量,重复性的焊接任务时效率最高。...

人工智能准确预测患者一年内的死亡风险,原理却无法解释
人工智能准确预测患者一年内的死亡风险,原理却无法解释

[2019-11-13]  图片来自BURGER PHANIE SCIENCE PHOTO LIBRARY美国最新研究显示,人工智能通过查看心脏测试结果,以高达85%以上的准确率预测了一个人在一 ...

谷歌宣布搜索算法重大升级,用BERT模型理解用户搜索意图
谷歌宣布搜索算法重大升级,用BERT模型理解用户搜索意图

[2019-10-26]  谷歌刚刚宣布,其搜索引擎的核心算法正在进行一项重大升级,这项升级可能会改变10%的搜索结果排序。此项升级应用了自然语言处理技术(BERT ...

为未来战场创造更有效的机器人 美国陆军研究人工纳米马达
为未来战场创造更有效的机器人 美国陆军研究人工纳米马达

[2019-10-11]  为了使机器人在战斗中更有效、更多才多艺地成为士兵的战友,美国陆军研究人员正在执行一项任务,即研究肌肉分子生命功能的价值,以及复制过 ...

科学家从蟑螂获得启发 教机器人更好地走路
科学家从蟑螂获得启发 教机器人更好地走路

[2017-12-11]  Weihmann指出:“我特别感到惊讶的是,动物运动稳定机制的变化与腿部协调的变化是一致的。昆虫的慢运行非常稳定,因为它的重心很低,三条腿总是以协调的方式运动。...

本周栏目热点

飞桨火力全开,重磅上线3D模型:PointNet++、PointRCNN!

[2020-03-26]  11 年前的「阿凡达」让少年的我们第一次戴上 3D 眼镜,声势浩大的瀑布奔流而下,星罗棋布飘浮在空中的群山,无一不体现着对生命的敬意, ...

从基础概念到数学公式,这是一份520页的机器学习笔记(图文并茂)

[2018-06-19]  近日,来自SAP(全球第一大商业软件公司)的梁劲(Jim Liang)公开了自己所写的一份 520 页的学习教程(英文版),详细、明了地介绍了机器学习中的相关概念、数学知识和各......

50行代码玩转生成对抗网络GAN模型!(附源码)

[2018-07-30]  本文为大家介绍了生成对抗网络(Generate Adversarial Network,GAN),以最直白的语言来讲解它,最后实现一个简单的 GAN 程序来帮助大家加深理解。...

神经网络和模糊逻辑的工作流

[2016-11-20]   行业观察 神经网络 和模糊逻辑的工作流 null 来源:神州数码erp 发布时间: 2009-10-14 9:06:01 关键词: 工作流,协同,B2B,OA  以下 ...

深度神经网络揭示了大脑喜欢看什么

[2019-11-06]  爱吧机器人网编者按:近日,《自然-神经科学》发表了一篇论文,研究人员创建了一种深度人工神经网络,能够准确预测生物大脑对视觉刺激所产 ...