GAN Note

GAN笔记

为了不被时代所抛弃,最近准备开始好好学一下深度学习的内容,主要针对CV方面,NLP先缓一缓,我还是会站在统计ML的角度来看DL的,今天先来总结一下最优美的GAN。

首先明确GAN是什么,GAN是一种生成模型。讲到这里先说一下生成模型和判别模型,这两个模型最大的区别就是在于求解方式不同,一个是求解\(P(X,Y)\) 联合概率,另一个是求\(P(Y|X)\) 条件概率。就拿之前很火的VAE来说明,我们通过encoder来讲样本encoding,这时候我们得到的embedding可能会服从某个分布\(\mathcal{P}\) ,也就是说我们的样本\(X\) 经过encoder之后分布变得可以进行简单的描述了,比如高斯分布,用均值和方差就可以描述清楚,这样,再经过decoder之后,我们就可以得到我们“伪造”的样本(此处假设你的AE模型已经训练好),所以之后我们就可以用简单的分布\(\mathcal{P}\) 和decoder来生成样本或用encoder来判别等等,GAN也是这样一个思想,只不过是将VAE中的判别部分更改了:VAE是通过衡量输入与输出的误差来训练encoder和decoder的,而GAN,是用一个网络来生成,另一个网络来判别的。

其次要明确为什么需要GAN,GAN的提出是为了解决生成模型的一个痛点问题,如何评判生成模型的好坏,此处引用别人的描述:

1

用GAN的话,就可以很好的规避上面这个问题,因为用一个判别网络(比如一个简单的CNN)会很好地从人的角度(也就是我们通常所说的相同与否)来判别生成样本的质量。

GAN那篇论文写得也很简明,所以我们也开始正文:

生成网络为\(G(z;\theta_g)\) 而判别网络为\(D(x;\theta_d)\) ,生成网络的目标就是生成以假乱真的样本,而判别网络的目标则是区分真的样本与假的样本,我们的目标是得到 \[ \min_G\max_DV(D,G)=\mathbb{E}_{x\sim p_{\text{data}}(x)}[\log D(x)]+ \mathbb{E}_{z\sim p_{z}}[\log(1- D(G(z))] \] 非常简单明了,前半部分是真实样本,后半部分是伪造样本(二分类的似然函数,参考对率回归)。作为一个从统计出发的学习者,一定会问这样一个问题,这个东西会收敛吗?好问题,下面我们就从理论角度来说明一下:

学习GAN模型,还是一个轮流迭代的方式,先给定生成网络\(G\) ,我们有 \[ \begin{aligned} V&=\mathbb{E}_{x\sim p_{\text{data}}(x)}[\log D(x)]+ \mathbb{E}_{x\sim p_{g}}[\log(1- D(x))]\\ &=\int_xp_{\text{data}}(x)\log(D(x))dx+\int_x p_{g}(x)\log(1-D(x))dx\\ &=\int_x[p_{\text{data}}(x)\log(D(x))+ p_{g}(x)\log(1-D(x))]dx \end{aligned} \] 当我们给定x的时候,等同于最优化 \[ f(D)=p_{\text{data}}(x)\log(D(x))+ p_{g}(x)log(1-D(x)) \] 此时对\(D\) 求并令其等于\(0\) \[ \frac{df(D)}{dD}=a\times\frac{1}{D}-b\times\frac{1}{1-D}=0\\ \frac{D}{1-D}=\frac{a}{b}\\ D^*=\frac{a}{a+b}=\frac{p_{\text{data}}(x)}{p_{\text{data}}(x)+p_g(x)} \] 再将最优的\(D\) 代回求解\(V\) \[ \begin{aligned} \max_DV(G,D)=&V(G,D^*)\\ =&\mathbb{E}_{x\sim p_{\text{data}}(x)}[\log\frac{p_{\text{data}}(x)}{p_{\text{data}}(x)+p_g(x)}]+ \mathbb{E}_{x\sim p_{g}}[\log(\frac{p_g(x)}{p_{\text{data}}(x)+p_g(x)})]\\ =&\int_xp_{\text{data}}(x)\log(\frac{p_{\text{data}}(x)}{p_{\text{data}}(x)+p_g(x)})+p_{g}(x)\log(\frac{p_g(x)}{p_{\text{data}}(x)+p_g(x)})dx\\ =&\int_xp_{\text{data}}(x)\log(\frac{p_{\text{data}}(x)}{p_{\text{data}}(x)+p_g(x)})+\log2-\log2\\ &+p_{g}(x)\log(\frac{p_g(x)}{p_{\text{data}}(x)+p_g(x)})dx+\log2-\log2 dx\\ =&\int_xp_{\text{data}}(x)\log(\frac{2p_{\text{data}}(x)}{p_{\text{data}}(x)+p_g(x)})+ p_{g}(x)\log(\frac{2p_g(x)}{p_{\text{data}}(x)+p_g(x)})dx-2log2\\ =&-2log2+\text{KL}(p_{\text{data}}(x)\|\frac{p_{\text{data}}(x)+p_g(x)}{2})+\text{KL}(p_g(x)\|\frac{p_{\text{data}}(x)+p_g(x)}{2})\\ =&-2log2+2\text{JSD}(p_{\text{data}}(x)\|p_g(x))\\ \leq&0 \end{aligned} \] 等号当且仅当\(p_{\text{data}}(x)=p_g(x)\) 时成立。所以我们将生成样本与真是样本相同比例混合送入训练,GAN终将收敛并取得最优,此时得到的“伪造”样本就足够优秀了。同时在训练时,采用判别器训练\(k\) 步,生成器训练\(1\) 步的方式进行迭代,训练算法如下:

2

这种想法的确相当优秀,用另一个网络来做判别,其实看到这里我第一个想到的是reinforcement learning和alphago, 都是采用深度学习来做背后判别,因为显示loss很可能无法刻画一些复杂情况,所以除了将模型部分做deep外,完全也可以把loss判别做成一个deep的模型,我感觉这种想法在以后会越来越多的出现的。哦对了,GAN还有很多细节问题我并没有在这里讲,大家如果想要深入了解GAN还是需要去多读一些相关的论文,有很多特殊的情况会有一些处理的trick。