SpanBERT

SpanBERT 是一个为了能够更好的表征和预测一个范围的文本的模型。Bert 模型是通过随机掩码一些 token 来对模型进行预训练,而 SpanBERT 不同的是,它通过随机掩码一个范围的文本对模型进行预训练。比如:像问答问题、关系抽取任务等,其特点都是从输入的文本中抽取一个 span 作为预测结果。所以,对于这些问题,SpanBERT 会有非常好的表现。

SpanBERT GitHub:GitHub – facebookresearch/SpanBERT: Code for using and evaluating SpanBERT
SpanBERT Paper:SpanBERT: Improving Pre-training by Representing and Predicting Spans

SpanBERT 主要提出了 SBO(span-boundary objective,区间便捷目标)的预训练任务。该任务如下图所示:

从上图中,我们可以看到 SpanBERT 在训练时,首先将输入的序列,进行了随机的连续掩码,例如:

输入:Super Bowl 50 was American football game to determine the champion
掩码:随机选择输入序列中的 an American football game 进行了掩码

SBO 的目标是预测被掩码的词是什么?比如上图中,我们要预测 [MASK] 为 football. 在训练过程中,football 的预测损失是由下面两部分计算得到,如下公式所示:

第一部分,是通过将 x7 (对应的是 football) 的编码 R7(编码器的输出向量)送入到分类层中得到预测值和真实值的交叉熵损失

第二部分,是通过将 x7 对应的位置编码向量,以及区间边界的两个 token 的编码器输出送入到分类层得到预测值和真实值的交叉熵损失。这里的边界 token 指的是 span 两侧的 token,比如上图中边界两侧的词指的是 was 和 to,表示出来就是 x4 和 x9,对应的编码器输出位 R4 和 R9。再次注意:计算这部分损失时,输入的是 x7 的位置编码信息 p7,而不是其对应的编码信息 R7。

SpanBERT 模型在训练中的优化目标就是降低该损失,从而获得一个在 QA、RE 任务中表现优秀的 SpanBERT 模型。

上面的 GitHub 中给出了两个 SpanBERT 的英文预训练模型:

  1. SpanBERT (base & cased): 12-layer, 768-hidden, 12-heads , 110M parameters
  2. SpanBERT (large & cased): 24-layer, 1024-hidden, 16-heads, 340M parameters

通过下面的命令下载该模型:

git clone https://huggingface.co/SpanBERT/spanbert-base-cased
git clone https://huggingface.co/SpanBERT/spanbert-large-cased
未经允许不得转载:一亩三分地 » SpanBERT
评论 (0)

3 + 4 =