MaskGAN 是 Goodfellow 组的新作,已经被 ICLR 2018 接收,标题很是风骚,MaskGAN: Better Text Generation via Filling in the ____,这个下划线的操作真是… astounding。代码也已开源。这篇文章依旧是熟悉的套路,从模型 + 代码来解读论文,走起!
SeqGAN To MaskGAN
SeqGAN 的缺点
上一篇讲 SeqGAN 的时候我们提到,SeqGAN 开创了 GAN 在 Text Generation 的先河,但是,实验结果证明,其 Idea 是能 Work(通过强化学习解决 GAN 无法在离散文本上梯度回传),合成数据中的 loss 确实有下降,但是在真实的古诗数据集上,其生成的文本质量不如人意。我利用全唐诗做了实验,不过囿于设备和时间原因,并没有充分的训练和调优,摘录部分生成结果如下:
霞畅拍起妇 已煦肃兢恼 鶋仝棚愕迷 啼肃次念云
岂阳孤任帐 因伊牧掩牢 人原马槎问 弥章斗天钓
鸡行肩始昏 晨刺重云千 指瘼山月堂 一似蕃德率
有足偶有欲 威飏欢浩潋 戏鸟靓簪粘 性负觉狄至
有没有一种狗屁不通的感觉… 反正我是很绝望。
也就是说 SeqGAN 效果不是很好(也有实验室做过实验,其中生成质量较好的古诗基本都是训练集中的),而 MaskGAN 可能为提升生成文本的质量指出了一个方向,其和 SeqGAN 有两点主要的区别:
- 增加额外的 Information,Masked Sequence $m(x)$,这也导致了其使用的模型架构变成 Seq2Seq,而非 SeqGAN 中 LSTM(Generator)和 CNN(Discriminator)
- 使用 Actor-Critic 来进行强化学习,而非 SeqGAN 中的 Policy Gradient + Monte Carlo
接下来我们就从这两点不同入手,来讲解 MaskGAN。
Masked Token
MaskGAN 在文中指出了 GAN 的两个问题,一是 Mode Collapse,即可能出现少数的生成样本种类占据了整个生成集,缺乏多样性;二是训练不稳定,GAN 难调试是出了名的。文章解决这两个问题的方案是:不再让生成器来生成的完整的文本,而是做“完形填空”,不过关于为什么能解决,他们是这么说的:
We believe the in-filling may mitigate the problem of severe mode-collapse.
一个believe
再来 may
加上一个 mitigate
,这就是论文的表述的艺术啊。解决训练不稳定的方法呢就是从 Policy Gradient 转换为 Actor-Critic,后面再说。
“完形填空”相信大家高中都做过,就是把文章挖空然后让你选一个正确的单词填进去,MaskGAN 就是这么干的,对于一个输入序列 $x = (x_1,…, x_T)$,经过一个 mask: $m=(m_1,…m_T)$,其中 $m_i$ 的取值为 0 或者 1,0 就代表挖掉,1 就意味着保留。经过挖空的操作之后呢,我们就得到了 Masked Token $m(x)$,并将它交给我们以 Seq2Seq 为架构的 Generator 来进行生成,模型见下:
需要注意一点就是:生成的 token 不一定会作为下一个生成的 pre-token,而是取决于是否被挖空,如有原;这也是一个重要的细节,因为一个错误的答案可能会导致一整篇文章都是错误的,所以,如果有参考答案还是用参考答案。
Discriminator 的架构也是采用的 Seq2Seq,只不过是 many2one,即最后生成的每个 token 为真的概率。除了有填好的句子做为输入以外,$m(x)$ 也作为 Discriminator 的输入,文章是这么解释这么做的原因的:对于一个生成的句子 the director director guided the series
,如果没有 $m(x)$ 的话,那么判别器无法分别到底前一个 director
是原文呢还是后一个是,因为句子有可能是 the *associate* director guided the series
或者是 the director *expertly* guided the series
,因此是有必要给判别器关于原文的信息,从而做出更好的判断。生成器和判别器的公式如下:
Actor-Critic
前面的文章谈到了,AC 的做法相比 Policy Gradient,很大的区别就在是单步更新,以及用一个 NN 来拟合 Advantage Function 来指导生成器生成更加逼真的文本。MaskGAN 的单步reward $r_t$ 设置为了 log probablity,也就是:
$$ r_t = log D_\phi(\hat{x_t}|\hat{x}_{0:T}, \ \textbf{m(x)})$$
总的 reward $R_t$ 则为这一时刻到句子结束 T 时之和:
$$R_t = \sum_{s=t}^T \gamma ^s r_s$$
我们通过减去一个 critic 产生的 baseline $b_t$ 来降低 variance,更新梯度的计算就变成:
其中 $b_t$ 由一个 NN 来拟合,MaskGAN 选择使用 Discriminator 的前半部分来估计 $b_t$,详细说明需要结合代码来进行。
Code Matters
代码永远是一个很好的学习材料,也是检验论文到底是不是糊弄人的试金石。
Generator
从代码里我们可以看到,作者实现了很多种 Generator 的架构,有 CNN、RNN 和 Seq2Seq,所以 Seq2Seq 应该是经过对比之后选出来效果比较好的一种。
先来看 Encoder 部分,其作用是把 Masked Token 交给一个 LSTM:
1 | def gen_encoder(hparams, inputs, targets_present, is_training, reuse=None): |
思路是这样的:输入的 Inputs,根据 targets_present(一个 bool 向量指示是否 mask)进行 mask 操作,然后丢进 RNN 里面,得到最后的 state 作为输出。
但这个代码里看到了几个 tricks:
- 在 RNN Cell 外再包了一层 Variational Dropout,每个 Unit 的 Dropout Rate 也是随机产生的,而不再是定值。推测是想要加强 Regularization 的作用,学到了。
- 在 Masked 的 Inputs 上进行一次 Encode 得到一个 final_masked_state 后,又在 Origin Input 上做了一次 Encode 得到 final_state,还不知道是干嘛用的,稍后再看。
接下来是 Decoder 部分:
1 | def gen_decoder(hparams, |
Decoder 的思路也是很直接,就是用 Encoder 传入的 state tuple,进行 token 的生成。有几点需要注意的是:
- 和论文中一样,如果有 real token,那么 real token 就会作为下一个 token 的 input,而非使用生成的 token
- 作者在设计的时候考虑到了使用 MLE 进行预训练的情况,这时候就全部使用 real tokens,并基于此生成一句话
Discriminator
文章说 Discriminator 的架构和 Generator 架构是一样的,只是最后输出是一个 scalar:
1 | with tf.variable_scope('dis', reuse=reuse): |
确实,dis_encoder
部分的代码和 gen_encoder
是一致的,实现也是类似的;
而 dis_decoder
:
1 | with tf.variable_scope('rnn') as vs: |
对于输入 sequence
,进行 embedding 后,拿出里面的每一个 token,交给 RNN,输出一个 probability,没问题!
Critic
论文中有提到一嘴,就是说 AC 这个算法是后来审稿人提出意见之后再加的。我一开始还担心代码里没有,但 Google 还是做的很不错的:
1 | def critic_seq2seq_vd_derivative(hparams, sequence, is_training, reuse=None): |
和文中所说的 head of discriminator
一致,代码中 Critic 的实现就是前半部分的 Discriminator,并且复用了 Discriminator 的参数,最后输出也就是一个 scalar,每个 token 的奖励 value;
Objective Function
我一开始以为公式中的 $r_t$ 是要计算每个 time step 的,但论文中的注释中说:
The REINFORCE objective should only be on the tokens that were missing. Specifically, the final Generator reward should be based on the Discriminator predictions on missing tokens.
The log probaibilities should be only for missing tokens and the baseline should be calculated only on the missing tokens.
也就是说,只在 missing tokens 上计算相应的 reward。这也很简单,对输出的进行一个 mask 就行:
1 | # Generator rewards are log-probabilities. |
代码中还实现了很多种 baseline,这里我们只看 Critic 作为 baseline 的情况:
1 | if FLAGS.baseline_method == 'critic': |
这段代码就对应上面的 $R_t - b_t$,只是这里的 $b_t$ 是由 Critic 产生的;剩下的几种 $b_t$ 里还有一半 Monte Carlo,一半 Critic 的情况,就不再细说。
Training
训练的套路呢也是类似的,先让预训练 Generator,再是进入 GAN 的一个对抗训练过程之中:
1 | # pretraining |
值得一提的是,loss 是 forward 和 backword 的平均值,这点似乎论文中并没有提到,算是作者的一个小心机?2333
Summary
这篇文章存了有一个礼拜才写完,总算是赶完了;代码部分读的还是很粗糙,接下来会继续把几篇 GAN + NLP 的文章好好读一下写笔记,跑个 Demo,然后试着写一篇 Overview 出来看看能不能忽悠住大家(逃