当前位置: 首页 > news >正文

GAN对抗生成网络(一)——基本原理及数学推导

1 背景

GAN(Generative Adversarial Networks)对抗生成网络是一个很巧妙的模型,它可以用于文字、图像或视频的生成。

例如,以下就是GAN所生成的人脸图像。

2 算法思想

假如你是《古董局中局》的文物造假者(Generator,生成器),希望生产出足以媲美王羲之真迹的假字画。GAN的思想是除了造假者再引入一位鉴宝者(Discriminator,判别器)。造假者的目标是“以假乱真”——精进手艺,从而骗过鉴宝者;鉴宝者的目标是“火眼金睛”——明察秋毫,区别出真字画和假字画。

造假者刚开始不掌握造假知识,只能天马行空地乱写一通,但收到鉴宝者的反馈,调整制作手段,逐步训练出了能够骗过鉴宝者的技艺。

鉴宝者刚开始不掌握鉴宝知识,但造假者的成果源源不断地输送过来,这也成为鉴宝者的学习材料,鉴宝者逐渐提高了鉴定能力。

GAN的全称是Generative Adversarial Networks。从名字也可以了解GAN的特性。

所谓Generative是指GAN是一种生成模型,即产生新的数据,并尽量与真实数据相似。生成模型是与判别模型相对的。

顾名思义,判别模型是为了做判断。训练好的判别模型在给定观测值后,预测标签。如果放到文物造假的例子,判别模型的做法就是找到一批真的文物和一批假的文物,提取出这些文物的特征(比如质地、颜色),学习这些特征如何导致了真或假的结果。这样给定新的文物,就可以根据特征判定文物是真还是假(比如质地为新则为假,质地为旧则为真)。

其实从上文描述可以看出,GAN中的discriminator就是判别模型,但GAN终极目的是为了生成新样本,所以才把GAN归为生成模型,其中的判别模型只是服务于生成模型,属于陪衬。

所谓Adversarial是指GAN是一种对抗模型,要训练出Generator和Discriminator二者相互对抗,你追我赶。

所谓Networks是指GAN是一种神经网络模型,无论是Generator还是Discriminator,他们都需要有足够复杂的表达式,才能实现图片甚至视频的生成,因此需要使用神经网络拟合复杂的运算。

3 数学推导

3.1 符号定义

符号含义文物例子中的对应概念
x样本,例如GAN的目标是生成图片,则此处泛指图像字画
P_{data}(x)真实数据概率分布真字画的技艺
GGenerator,本质是一个神经网络,输入概率分布中的随机噪声,输出为x造假者
G^{*}最优的Generator训练好的造假者
P_{G}(x)Generator对应的概率分布造假者的技艺
z概率分布中的噪音,例如高斯分布抽样出的样本造假者的输入,可理解为当时模具的状态、制作哪幅字画的决定
Div两个分布之间的距离真实制作工艺和造假制作工艺的区别
DDiscriminator,本质是一个神经网络,输入一个样本x,输出为1或0,代表x是真实的概率鉴宝者
D^{*}最优的Discriminator训练好的鉴宝者
V优化目标,越大越好真品和赝品的相似程度,或者赝品能卖出的价钱
\theta_GGenerator的参数P_{G}概率分布的参数
\theta_{D}Discriminator的参数鉴宝者鉴宝所遵从的规则参数,比如质地有“多旧”才会认为是真品
\eta误差反向传播时神经网络的学习率造假者或鉴宝者更新自己造假技艺和更新技艺时调整的幅度

3.2 解析

以上的例子中造假者的目标是学习出一套接近真实的技艺,但技艺太抽象了,需要用数学的语言表示,这里使用“概率分布”表示技艺。可能不好理解,解释如下。

第一,这里的一张字画可比作训练的样本,那么生产字画的技艺自然就可以比作产生样本的概率分布。

第二,无论是生成图片还是生成字画的过程,本质上就是从概率分布中抽样的过程。

例如我们要生成一个足球图片,每个像素点是否应该为空,应该呈现什么颜色,都是有概率的。我们应该学习到这么一种概率分布:在一个画布上,产生圆形的概率最高,这一点毋庸置疑;产生椭圆形的概率次高,可能有时存在镜头畸变现象,或者画面想表现出足球快速向前运动(参考下图中的热血足球);产生正方形的概率几乎为0。此外,颜色也是用概率分布产生的,黑色和白色概率最高,但不排除有其他颜色的情况。

可能会有读者问,在生成字画的例子中,不同文章的字完全不一样,怎么能用一种概率分布来表示呢?这里的概率分布可能是非常复杂的分布,例如分解成先用均匀分布选择画哪幅字,然后用条件概率再决定画布每个字颜色、落笔力度、方向等。

再举一个更简单的例子,比如给猴子一台有26个英文字母的打字机,它任意地敲击100个字符,那么概率分布空间就有100^{26}种,属于一维均匀分布。猴子每次敲击出的文章就是从这个概率分布中抽样一次。在这个例子中,打字机的规则十分清晰,即英文字母按顺序排列,因此比创作字画这种高维分布简单得多。

第三,只有引入概率分布才能生成大量样本。这也对应开篇所说的造假者目标,是学习一种技艺,而不是学习单一的某张字画。

如果理解了概率分布,那么也就明白了“符号定义”中的z。它代表了造假者的输入,例如要决定画哪幅字画,当时画笔的状态等等随机的因素。z一般是高斯分布,经过G的转换得到x_{1},服从更复杂的分布,这种分布就代表技艺。

3.3 推导

3.3.1 Discriminator

先说结论,推导后的最优解等价于JS散度最大。JS散度表示2个分布之间的距离,越大表示2个分布的差异越明显。这也符合Discriminator的目标,就是要让真品和赝品的区别最明显。

推导过程如下,不感兴趣的读者可略过。

对于D而言,优化目标是下式最大化。

V(G,D)=E_{x\sim P_{data}}[logD(x)]+E_{x\sim P_{G}}[log(1-D(x))]

左侧代表样本来自于真实数据时,D预测为真实的概率,此值越大越好;右侧表示当样本来自于假数据时,D预测为真实的概率,此值越小越好。取对数不影响单调性。

将此式的期望转换为积分的形式

V(G,D)=\int_{x}{P_{data}(x)\log{D(x)}dx}+\int_{x}{P_{G}(x)\log{(1-D(x))}dx}=\int_{x}{[P_{data}(x)\log{D(x)}+P_{G}(x)log{(1-D(x))}]dx}

现在引入一个强假设——D(x)可以是任意的函数。意思是对于每一个xD(x)足够复杂,可以想象成D(x)是一个分段函数,使得每个x上式中的积分项里的内容都最大。因此问题转化为积分项内的式子最大。因为D本质上是神经网络,足够复杂,所以是合理的。

为了方便说明,记a=P_{data}(x)b=P_{G}(x),D=D(x),V(G,D)中的每一个积分项都是f(D)。注意当优化D时,G是给定的,因此可看作常量。

上式转化为f(D)=a\log{(D)}+b\log{(1-D)}

为了求f的最大值,求上式微分等于0时的D(二阶微分小于0,因此可断定f为最大值而非最小值)

\frac{\mathrm{d} f(D) }{\mathrm{d} D}=\frac{a}{D}-\frac{b}{1-D}=0

可得D^*(x)=\frac{a}{a+b}=\frac{P_{data}(x)}{P_{data}(x)+P_{G}(x)}. 代入V(G,D),得

\max\limits_{D}V(G,D) =V(G,D^*)=\int_{x}{[P_{data}(x)\log{\frac{P_{data}(x)}{P_{data}(x)+P_{G}(x)}}+P_{G}(x)\log{\frac{P_{G}(x)}{P_{data}(x)+P_{G}(x)}}]dx}

在对数式分子和分母同时乘以1/2,并将分子的1/2提出来,得到

V(G,D^*)=-2\log{2}+\int_{x}{[P_{data}(x)\log{\frac{P_{data}(x)}{\left(P_{data}(x)+P_{G}(x)\right)/2}}+P_{G}(x)\log{\frac{P_{G}(x)}{\left( P_{data}(x)+P_{G}(x)\right)/2}}]dx}

将中括号内的求和拆成2项,每一项都是KL散度,合起来就是JS散度

V(G,D^*)=-2\log{2}+KL\left(P_{data}||\frac{P_{data}+P_{G}}{2}\right)+KL\left(P_{G}||\frac{P_{data}+P_{G}}{2}\right)=-2\log{2}+2JS(P_{data}||P_{G})

GAN神奇的地方在于,在我们不清楚真实分布P_{data}参数的情况下,D也能够用抽样的方式判断分布之间的距离,从而使得G趋近这个分布。

3.3.2 Generator

对于G而言,要找出最佳的G^{*},使得P_{G}P_{data}越接近越好,即

G^{*}=\arg\min\limits_{G}Div(P_{G},P_{data})

根据上文所述,这里Div(P_{G},P_{data})就是\max\limits_{D}V(G,D),即G^{*}=\arg\min\limits_{G}\max\limits_{D}V(G,D)。这就转化成G在知道D会选择让真实分布和生成分布距离最大时,应该做出什么样的决策,使得两个分布的距离最小。

这里借用李宏毅老师的例子并做了些修改:G只有2种选择,G_{1}G_{2},那么G应该选择哪一种呢?对应下图,左右分别是2种选择,横轴代表D的决策空间,纵轴代表真实分布和生成分布的距离。

如果选择第一种,D一定会选择让V最大的参数D_{1}^{*},此时V的值对应左图红点的纵坐标。同理,如果选择第二种,则V的值对应右图红点的纵坐标。由于左图的红点纵坐标比右图高,所以G选择第二个。

从另一个角度理解,仔细观察下图,左图的蓝色线整体比较低,只有一处尖峰位置较高;右图的蓝色线整体比较高。还是回到文物造假的例子,造假者有2种造假手段,第一种造出的赝品整体和真品差距不大,但是有一处非常容易露馅(比如整体的字都和王羲之的字很像,但只有某一个字的某一个笔画和王羲之相去甚远);第二种造出的赝品和真品的差距不大不小,没有明显的破绽。造假者和鉴宝者博弈,知道鉴宝者会揪住很小的破绽不放,因此选择了更稳妥的第二种造假手段。

下一篇笔者介绍GAN的具体算法和Python实现代码。

参考资料

李宏毅对抗生成网络2018

B站机器学习白板推导

《进阶详解KL散度》https://zhuanlan.zhihu.com/p/372835186


http://www.mrgr.cn/news/81854.html

相关文章:

  • 速卖通AliExpress商品详情API接口深度解析与实战应用
  • 金融租赁系统的创新发展与市场竞争力提升探讨
  • java实现预览服务器文件,不进行下载,并增加水印效果
  • QT------------QT框架中的模块
  • Java - 日志体系_Apache Commons Logging(JCL)日志接口库_适配Log4j2 及 源码分析
  • Lua语言的文件操作
  • CSS 居中技术完全指南:从基础到高级应用
  • Java重要面试名词整理(十二):Netty
  • Windows Knowledge
  • RTLinux和RTOS基本知识
  • Oracle数据库中用View的好处
  • Doris使用注意点
  • java相关学习文档或网站整理
  • 小程序基础 —— 02 微信小程序账号注册
  • GDPU 数据库原理 期末复习(持续更新……)
  • 微信小程序 app.json 配置文件解析与应用
  • 小程序基础 —— 08 文件和目录结构
  • mybatis基础学习
  • 小程序配置文件 —— 13 全局配置 - window配置
  • csrf跨站请求伪造(portswigger)无防御措施
  • 小程序配置文件 —— 12 全局配置 - pages配置
  • springMVC-请求响应
  • 数据分析与应用:如何分析7日动销率和滞销率?
  • 经典问题——华测
  • 【论文阅读】Reducing Activation Recomputation in Large Transformer Models