XLNet

Bert 在预训练时,会将输入中一部分的 Token 随机替换为 MASK 标记, 即:在预训练阶段输入中存在 MASK 标记,但在微调时,输入的文本中是不包括 MASK 标记。这就是 Bert 预训练和微调时不匹配,会影响到模型的性能。并且 Bert 对输入有 max_length 的限制。

Paper:https://arxiv.org/pdf/1906.08237.pdf

XLNET 就想:

  1. 在预训练时,去掉 MASK 标记,同时希望能够保留 Bert 利用上文语义的优点;
  2. 希望能够支持更长的文本输入,能够更好的利用长文本的依赖关系。

XLNET 结束长文本输入限制时,参考了 Transformer-XL 的段循环机制和相对位置编码技术。而针对第一个想法则提出了两个创新点:

  1. Permutation Language Modeling, PLM,中文有多个类似的不同的叫法:排列语言模型、有序因子排列、重排语言模型
  2. Two-Stream Self Attention 双流自注意力机制,这个机制的提出主要是针对使用 PLM 时,会带来一些问题。

1. Permutation Language Modeling

Bert 采用 AE 语言模型,该模型能够利用输入上下文。不同于 Bert 模型,XLNET 采用的是 AR 语言模型。我们都知道 AR 语言模型只能利用单向语义,或只能利用上文信息,或利用下文信息,但是不能利用上下文信息。

XLNET 的想法是既能利用 AR 语言模型,也能够利用到输入上下文。所以,对 AR 语言模型进行改进。其改进的方式思路很简单,就是对输入的序列进行排序,这样预测位置下文的内容有可能会出现在上文中,此时 AR 就可以利用上下文信息。

例如,输入一个序列 [x1, x2, x3, x4], 将其重新排列的话会有 \(A^4_4\) 种,我们选择其中几个排列方式来分析下,AR 如何通过对输入重排来利用上文:

  1. 假设排序后的序列为:[x3, x2, x4, x1],我们要预测 x3,但是 x3 前面没有任何内容,所以无法利用任何其他信息。
  2. 假设排序后的序列为:[x2, x4, x3, x1],我们要预测 x3,此时 x3 左侧有 x2 和 x4,此时我们就可以利用了上文的 x2 和 下文的 x4 信息。

在理解上面例子时,一定要清楚,我们使用的是 AR 语言模型,它只能利用到上文的信息。重新对输入进行排列之后,原输入序列中的下文 Token 就可能会出现在上文中,实现了利用上下文。

我们简单对比下 Bert MASK 和 XLNET PLM:

输入:我爱北京天安门

  1. Bert 模型通过 MASK 方式,我爱[MASK]京天[MASK]门,在预测第一个 MASK 时,可以使用到 “我爱京天门” 等上下文信息,由于输入中除了当前的字词被掩盖,还有其他词也被掩盖了,所以 Bert 不是利用了其他所有的上下文。
  2. XLNET 通过 PLM 对输入进行重排:我爱京天门[北]安, 我们发现将 “我爱京天门” 排序到了上文,同样实现了在 AR 基础上,预测 “北” 时利用到了部分上下文信息。

所以,我们把 PLM 可以理解为 XLNET 对 Bert MASK 方式的一种优化或者改进,目的就是解决训练时多了一个 MASK,而微调时没有 MASK 的问题,实现预训练过程和微调过程的一致。

XLNET 并不是将输入序列真的打乱来实现重排的,而是通过 attention mask 来实现的,假设:我们重排得到的一个序列是:3->2->4->1,XLNET 会随机构建的掩码矩阵如下:

第一行有 4 个红点,表示第一个 token 1 会利用 1、2、3、4 token 的信息;
第二行有 2 个红点,表示第二个 token 2 会利用 2、3 token 的信息;
第三行有 1 个红点,表示第三个 token 3 会利用 3 token 的信息;
第四行有 3 个红点,表示第四个 token 4 会利用 2、3、4 token 的信息。

就像我们前面说过的,包含 4 个 token 的输入序列,重排之后会有 \(A^4_4\) 种排列组合,也是说会存在\(A^4_4\) 种上面的 attention mask 矩阵。当我们预测某个词时,不会选择所有的组合,而是会进行采样。并且进行预测时,为了训练效率的考虑,也不会预测所有的 token,而是采用 Partial Prediction,即:只对后 1/k 的 token 进行预测

这就类似 Bert mask 时的序列有很多种,我们只选择其中的一种或者多种。

2. Two-Stream Self Attention

当 Bert 预测某个位置词的时候,Bert 使用 [MASK] token 来代表该词,使得模型看不到该词的信息。然后自注意力机制计算当前词的预测张量(就是 MASK token 计算得到的张量)。

当 XLNet 对输入 PLM 时,由于没有 mask 的存在,输入的 token 都包含了位置信息和内容信息。那么当预测某个 token 时,就出现了输入当前 token 的信息预测 token 的情况,这个显然不合理。

XLNET 就是使用双流自注意力机制来解决这个问题,所谓的双流指的是内容流(content stream)和查询流(query stream),这个内容流就是标准的自注意力机制,content stream 的作用如下图所示:

训练的时候我们需要对每个 token 都要进行语义学习,这就是通过内容流来完成的,其实也就是标准自注意力机制的方式来进行计算的。计算过程如下图所示:

根据 content stream 的掩码矩阵,我们知道计算其语义张量时,需要依赖 1、2、3、4 token。

我们前面在学习 Bert 的时候,是通过对 mask token 进行预测来学习 token 与 token 之间的语义依赖关系。在 XLNet 里需要进行预测,这个预测是通过查询流来完成的,如下图所示:

先看图右侧多了一个查询流注意力掩码矩阵,对角线被圈起来了。这表示当预测第一个 token 时,不允许看到第一个 token 的 content,只能看到 2、3、4 token 的 content。

这里的 \(g^{(0)}_{1}\) 指的就是前一个图中的 w(初始化为一个可学习的参数),通过查询流掩码知道如何获得当前预测的张量,该张量就可以送入输出层得到预测结果。

论文中给出的完整的示意图如下:

左上角是内容流,左下角是查询流,右侧为双流注意力的计算过程。这里需要注意的是,我们在输入时不会打乱文本,这个打乱的操作是由 XLNET 内部自己做的,其内部是通过产生 Attention Mask 来实现打乱,Attention Mask 有查询流的 Mask,也有内容流的 Mask。不同的排列都会对应不同的 Attention Mask。

双流自主力机制会对每个 token 产生两个结果:一个是根据其他位置的词对当前位置的预测输出 g,另一个则是根据所有词的依赖来获得当前词的表征 h。

由 content stream 产生词的表征,由 query stream 得到对某个位置 token 的预测。query stream 的预测需要 content stream 产生的词的表征。

到这里的话,也能理解,在预训练阶段,我们需要双流来支持在大语料上的无监督训练。微调时,只需要内容流就可以了。

未经允许不得转载:一亩三分地 » XLNet