SpanBERT

SpanBERT 是由 Facebook AI 在 2019 年提出的一种改进版本的 BERT。它的核心改进点在于 增强对 span(文本片段)的表示能力,从而在涉及 实体识别、关系抽取 等任务时比 BERT 表现更优。模型的改进之处:

  • Span-based Masking(基于 Span 的掩码机制),ERT 采用的是 单个 token 的随机掩码,而 SpanBERT 采用 连续的片段(span)进行掩码。这样能让模型更好地学习跨多个 token 的上下文信息,特别适用于实体识别和关系抽取任务。
  • Span Boundary Objective (SBO)(Span 边界目标),在掩码的 span 中,模型不仅要预测被遮蔽的 token,还要学习该 span 边界的隐藏向量,以便更好地表示整个 span。
  • 无 NSP 任务(Next Sentence Prediction),SpanBERT 移除了 BERT 的 Next Sentence Prediction(NSP)任务,而是更加专注于 句子内部的 span 表示学习

GitHub:https://github.com/facebookresearch/SpanBERT
Paper:https://arxiv.org/abs/1907.10529

从上图中,我们可以看到 SpanBERT 在训练时,首先将输入的序列,进行了随机的连续掩码,例如:

输入:Super Bowl 50 was an American football game to determine the champion
掩码:随机选择输入序列中的 an American football game 进行了掩码

我们要预测 [MASK] 为 football. 在训练过程中,football 的预测损失是由下面两部分计算得到,如下公式所示:

  • 第一部分,是通过将 x7 (对应的是 football) 的编码 R7(编码器的输出向量)送入到分类层中得到预测值和真实值的交叉熵损失
  • 第二部分,是通过将 x7 对应的位置编码向量,以及区间边界的两个 token 的编码器输出送入到分类层得到预测值和真实值的交叉熵损失。这里的边界 token 指的是 span 两侧的 token,比如上图中边界两侧的词指的是 was 和 to,表示出来就是 x4 和 x9,对应的编码器输出位 R4 和 R9。再次注意:计算这部分损失时,输入的是 x7 的位置编码信息 p7,而不是其对应的编码信息 R7。

SpanBERT 模型在训练中的优化目标就是降低该损失,从而获得一个在 QA、RE 任务中表现优秀的 SpanBERT 模型。

上面的 GitHub 中给出了两个 SpanBERT 的英文预训练模型:

  1. SpanBERT (base & cased): 12-layer, 768-hidden, 12-heads , 110M parameters
  2. SpanBERT (large & cased): 24-layer, 1024-hidden, 16-heads, 340M parameters

通过下面的命令下载该模型:

git clone https://huggingface.co/SpanBERT/spanbert-base-cased
git clone https://huggingface.co/SpanBERT/spanbert-large-cased
未经允许不得转载:一亩三分地 » SpanBERT
评论 (0)

6 + 5 =