缘起
Seq2Seq 的 TensorFlow 实现有很多,而 TensorFlow 之前也推出了一套新的 API,文档依旧是令人蛋疼地杂乱。最近在用新的 API 写一个 Many2Many 结构的 Seq2Seq,踩到了一个坑,记录之,也进一步地提醒我应该把主力框架迁移到 PyTorch 提上日程了。
问题描述
根据示例,只要把 encoder 和 decoder 拼接起来,并且使用 TrainingHelper + BasicDecoder:
1 | # 省略参数 |
其中 logits
是 vocab_size 上的分布,shape 为 [batch_size, ?, vocab_size]
而 target
则是正确的输出,这里我将其做了 padding,补齐到最大长度,shape 为 [batch_size, max_len]
,target_mask
是对应的一个权重,非补齐的部分才会参与到 loss 的计算。
如果不 Padding,则会在 feed_dict 这一步报错:
ValueError: setting an array element with a sequence.
因为 target_input 的长度是变长的 ,NumPy 无法将其视作一个 array,导致错误。
然后问题就来了,sequence_loss
函数报错,说其内部调用的 sparse_soft_max
函数的 labels
和 logits
第一维不匹配,其中:logits 的形状为 [?, vocab_size]
;label 的形状为 [batch_size x max_len]
为什么会这样呢?来进一步的分析分析
Source Code
首先是看一下 sequence_loss
的源代码:
1 | with ops.name_scope(name, "sequence_loss", [logits, targets, weights]): |
先是对 logits 和 labels 做了 reshape 操作,之前的形状也能对上,也就是说 logits reshape 之后的 ? = ? x batch_size
,那么这个 ?
究竟是是什么呢,不出意外,应该是生成的序列长度,但为什么是不定长的呢?
来看看 TrainingHelper
的核心源代码:
1 | # TrainingHelper |
next_inputs
函数是 RNN 不断获取下一步的迭代函数,其中 finished
相当于指示了当前的时间步是否已经超出最大长度,即 self._sequence_length
一个记录目标输出长度 int32 向量。reduce_all()
是对某一个维度求逻辑与,如果没有指定 axis 参数,则是对所有元素进行与运算。由此,我们可以得知:当未运行至 target 目标时间步时,会用 Groud-Truth 作为下一个的输入,否则则为 0。但这里并没有给出什么时候停止,所以进一步地,看看 decoder 在哪里调用这个的函数:
1 | def step(self, time, inputs, state, name=None): |
这是 decoder 的 step
函数,相当于对 helper
进一步地封装,以便使用一些功能(TrainingHelper 是训练时 feed groud-truth,在 inference 阶段会使用 GreedyEmbeddingHelper 在无 Groud-Truth 帮助下进行生成等)。再向上,我们去看 dynamic_decode
的函数:
1 | # dynamic_decode 为了节约篇幅,仅保留重要的代码 |
这里,总算看到了我们要的 while loop,循环的控制变量是 finished
,而其又 是 decoder_finished 和 finished 的逻辑或所得到,所以可以得出:当 decoder 解码到 sequence_length 时,其才会停止;另一方面,因为一个 batch 中的长度不都相同,所以得到的 dynamic_length 应该是某个 batch 中最长的一句的长度,到了这里,问题总算是知道根源所在了,那么,怎么解决呢?
Solution
GitHub 上有人提出过这个问题 Issue ,并且有很长的讨论,有一个比较粗暴的解决方案,在无法喂给它 padding 之前的情况下,对 target 做一个截取,因为之前的研究能够让我们确信生成的序列的长度是一定小于等于 batch 中 target 最长的长度的,所以:
1 | # 获取当前的长度,max_len 和 logits 的较小者,事实上,我们可以认为就是 logits 的长度 |
截取之后,就可以保证二者的长度一致,再使用sequence_loss()
计算就可以了。
另外有一个疑问就是,有些代码是可以运行不会报错(网上 Seq2Seq 的教程都是这么写的),猜测是输入数据的格式问题,日后碰到了再提。