GAN对抗生成网络(一)——基本原理及数学推导
1 背景
GAN(Generative Adversarial Networks)对抗生成网络是一个很巧妙的模型,它可以用于文字、图像或视频的生成。
例如,以下就是GAN所生成的人脸图像。
2 算法思想
假如你是《古董局中局》的文物造假者(Generator,生成器),希望生产出足以媲美王羲之真迹的假字画。GAN的思想是除了造假者再引入一位鉴宝者(Discriminator,判别器)。造假者的目标是“以假乱真”——精进手艺,从而骗过鉴宝者;鉴宝者的目标是“火眼金睛”——明察秋毫,区别出真字画和假字画。
造假者刚开始不掌握造假知识,只能天马行空地乱写一通,但收到鉴宝者的反馈,调整制作手段,逐步训练出了能够骗过鉴宝者的技艺。
鉴宝者刚开始不掌握鉴宝知识,但造假者的成果源源不断地输送过来,这也成为鉴宝者的学习材料,鉴宝者逐渐提高了鉴定能力。
GAN的全称是Generative Adversarial Networks。从名字也可以了解GAN的特性。
所谓Generative是指GAN是一种生成模型,即产生新的数据,并尽量与真实数据相似。生成模型是与判别模型相对的。
顾名思义,判别模型是为了做判断。训练好的判别模型在给定观测值后,预测标签。如果放到文物造假的例子,判别模型的做法就是找到一批真的文物和一批假的文物,提取出这些文物的特征(比如质地、颜色),学习这些特征如何导致了真或假的结果。这样给定新的文物,就可以根据特征判定文物是真还是假(比如质地为新则为假,质地为旧则为真)。
其实从上文描述可以看出,GAN中的discriminator就是判别模型,但GAN终极目的是为了生成新样本,所以才把GAN归为生成模型,其中的判别模型只是服务于生成模型,属于陪衬。
所谓Adversarial是指GAN是一种对抗模型,要训练出Generator和Discriminator二者相互对抗,你追我赶。
所谓Networks是指GAN是一种神经网络模型,无论是Generator还是Discriminator,他们都需要有足够复杂的表达式,才能实现图片甚至视频的生成,因此需要使用神经网络拟合复杂的运算。
3 数学推导
3.1 符号定义
符号 | 含义 | 文物例子中的对应概念 |
样本,例如GAN的目标是生成图片,则此处泛指图像 | 字画 | |
真实数据概率分布 | 真字画的技艺 | |
Generator,本质是一个神经网络,输入概率分布中的随机噪声,输出为x | 造假者 | |
最优的Generator | 训练好的造假者 | |
Generator对应的概率分布 | 造假者的技艺 | |
概率分布中的噪音,例如高斯分布抽样出的样本 | 造假者的输入,可理解为当时模具的状态、制作哪幅字画的决定 | |
两个分布之间的距离 | 真实制作工艺和造假制作工艺的区别 | |
Discriminator,本质是一个神经网络,输入一个样本x,输出为1或0,代表x是真实的概率 | 鉴宝者 | |
最优的Discriminator | 训练好的鉴宝者 | |
优化目标,越大越好 | 真品和赝品的相似程度,或者赝品能卖出的价钱 | |
Generator的参数 | 概率分布的参数 | |
Discriminator的参数 | 鉴宝者鉴宝所遵从的规则参数,比如质地有“多旧”才会认为是真品 | |
误差反向传播时神经网络的学习率 | 造假者或鉴宝者更新自己造假技艺和更新技艺时调整的幅度 |
3.2 解析
以上的例子中造假者的目标是学习出一套接近真实的技艺,但技艺太抽象了,需要用数学的语言表示,这里使用“概率分布”表示技艺。可能不好理解,解释如下。
第一,这里的一张字画可比作训练的样本,那么生产字画的技艺自然就可以比作产生样本的概率分布。
第二,无论是生成图片还是生成字画的过程,本质上就是从概率分布中抽样的过程。
例如我们要生成一个足球图片,每个像素点是否应该为空,应该呈现什么颜色,都是有概率的。我们应该学习到这么一种概率分布:在一个画布上,产生圆形的概率最高,这一点毋庸置疑;产生椭圆形的概率次高,可能有时存在镜头畸变现象,或者画面想表现出足球快速向前运动(参考下图中的热血足球);产生正方形的概率几乎为0。此外,颜色也是用概率分布产生的,黑色和白色概率最高,但不排除有其他颜色的情况。
可能会有读者问,在生成字画的例子中,不同文章的字完全不一样,怎么能用一种概率分布来表示呢?这里的概率分布可能是非常复杂的分布,例如分解成先用均匀分布选择画哪幅字,然后用条件概率再决定画布每个字颜色、落笔力度、方向等。
再举一个更简单的例子,比如给猴子一台有26个英文字母的打字机,它任意地敲击100个字符,那么概率分布空间就有种,属于一维均匀分布。猴子每次敲击出的文章就是从这个概率分布中抽样一次。在这个例子中,打字机的规则十分清晰,即英文字母按顺序排列,因此比创作字画这种高维分布简单得多。
第三,只有引入概率分布才能生成大量样本。这也对应开篇所说的造假者目标,是学习一种技艺,而不是学习单一的某张字画。
如果理解了概率分布,那么也就明白了“符号定义”中的。它代表了造假者的输入,例如要决定画哪幅字画,当时画笔的状态等等随机的因素。一般是高斯分布,经过的转换得到,服从更复杂的分布,这种分布就代表技艺。
3.3 推导
3.3.1 Discriminator
先说结论,推导后的最优解等价于JS散度最大。JS散度表示2个分布之间的距离,越大表示2个分布的差异越明显。这也符合Discriminator的目标,就是要让真品和赝品的区别最明显。
推导过程如下,不感兴趣的读者可略过。
对于而言,优化目标是下式最大化。
左侧代表样本来自于真实数据时,D预测为真实的概率,此值越大越好;右侧表示当样本来自于假数据时,D预测为真实的概率,此值越小越好。取对数不影响单调性。
将此式的期望转换为积分的形式
现在引入一个强假设——可以是任意的函数。意思是对于每一个, 足够复杂,可以想象成是一个分段函数,使得每个上式中的积分项里的内容都最大。因此问题转化为积分项内的式子最大。因为本质上是神经网络,足够复杂,所以是合理的。
为了方便说明,记, ,,中的每一个积分项都是。注意当优化D时,G是给定的,因此可看作常量。
上式转化为
为了求f的最大值,求上式微分等于0时的D(二阶微分小于0,因此可断定f为最大值而非最小值)
可得. 代入,得
在对数式分子和分母同时乘以1/2,并将分子的1/2提出来,得到
将中括号内的求和拆成2项,每一项都是KL散度,合起来就是JS散度
GAN神奇的地方在于,在我们不清楚真实分布参数的情况下,也能够用抽样的方式判断分布之间的距离,从而使得趋近这个分布。
3.3.2 Generator
对于而言,要找出最佳的,使得和越接近越好,即
根据上文所述,这里就是,即。这就转化成在知道会选择让真实分布和生成分布距离最大时,应该做出什么样的决策,使得两个分布的距离最小。
这里借用李宏毅老师的例子并做了些修改:只有2种选择,和,那么应该选择哪一种呢?对应下图,左右分别是2种选择,横轴代表的决策空间,纵轴代表真实分布和生成分布的距离。
如果选择第一种,一定会选择让V最大的参数,此时V的值对应左图红点的纵坐标。同理,如果选择第二种,则V的值对应右图红点的纵坐标。由于左图的红点纵坐标比右图高,所以选择第二个。
从另一个角度理解,仔细观察下图,左图的蓝色线整体比较低,只有一处尖峰位置较高;右图的蓝色线整体比较高。还是回到文物造假的例子,造假者有2种造假手段,第一种造出的赝品整体和真品差距不大,但是有一处非常容易露馅(比如整体的字都和王羲之的字很像,但只有某一个字的某一个笔画和王羲之相去甚远);第二种造出的赝品和真品的差距不大不小,没有明显的破绽。造假者和鉴宝者博弈,知道鉴宝者会揪住很小的破绽不放,因此选择了更稳妥的第二种造假手段。
下一篇笔者介绍GAN的具体算法和Python实现代码。
参考资料
李宏毅对抗生成网络2018
B站机器学习白板推导
《进阶详解KL散度》https://zhuanlan.zhihu.com/p/372835186