Longformer

Tansformer-based 的模型都是基于自注意力机制,我们知道自注意力机制擅长捕捉输入 Token 内部相关性,并以此能够建立对 Token 的表征。但是自注意力机制随着输入序列长度的增加,所需的计算量也变得很大。这就使得对 transformer 处理更长的文本输入开销很大。当然,还有存储空间的需求也会更大。这也是为什么我们说 transformer 不适合处理长文本的原因。

我们通过一个例子来理解下计算量的大小。比如:

  1. 首先,输入序列长度是 512,维度为 768
  2. 然后,输入序列长度是 1024,维度为 768
  3. 对比单条样本不同输入长度的所需的计算量

我们这里的实验使用的是 ptflops 库里的 get_model_complexity_info 方法来完成,示例代码如下:

from transformers import BertModel
from transformers import BertConfig
from ptflops import get_model_complexity_info


def test():
    # 初始化 Bert 模型
    model = BertModel(config=BertConfig())
    # 获得 Bert 自注意力计算对象
    self_attn = model.encoder.layer[0].attention.self
    print(self_attn)
    print('-' * 50)

    # 计算输入长度为 512 时,模型的计算量
    flops, params = get_model_complexity_info(model=self_attn, input_res=(512, 768))
    print(flops, params)
    print('-' * 50)

    # 计算输入长度为 1024 时,模型的计算量
    flops, params = get_model_complexity_info(model=self_attn, input_res=(1024, 768))
    print(flops, params)


if __name__ == '__main__':
    test()

程序执行结果:

BertSelfAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
--------------------------------------------------
BertSelfAttention(
  1.77 M, 100.000% Params, 905.97 MMac, 100.000% MACs, 
  (query): Linear(590.59 k, 33.333% Params, 301.99 MMac, 33.333% MACs, in_features=768, out_features=768, bias=True)
  (key): Linear(590.59 k, 33.333% Params, 301.99 MMac, 33.333% MACs, in_features=768, out_features=768, bias=True)
  (value): Linear(590.59 k, 33.333% Params, 301.99 MMac, 33.333% MACs, in_features=768, out_features=768, bias=True)
  (dropout): Dropout(0, 0.000% Params, 0.0 Mac, 0.000% MACs, p=0.1, inplace=False)
)
905.97 MMac 1.77 M
--------------------------------------------------
BertSelfAttention(
  1.77 M, 100.000% Params, 1.81 GMac, 100.000% MACs, 
  (query): Linear(590.59 k, 33.333% Params, 603.98 MMac, 33.333% MACs, in_features=768, out_features=768, bias=True)
  (key): Linear(590.59 k, 33.333% Params, 603.98 MMac, 33.333% MACs, in_features=768, out_features=768, bias=True)
  (value): Linear(590.59 k, 33.333% Params, 603.98 MMac, 33.333% MACs, in_features=768, out_features=768, bias=True)
  (dropout): Dropout(0, 0.000% Params, 0.0 Mac, 0.000% MACs, p=0.1, inplace=False)
)
1.81 GMac 1.77 M

我们从结果可以看到,当单条样本输入长度为 512 时,所产生的计算量为 905 MMac,而当长度修改为 1024 时,所产生的计算量为 1.81 GMac。Mac 和 FLOPs 一样,在这里都是用于计量模型浮点数运算量的单位。Mac 表示 Multiply–Accumulate Operations,即:把一次浮点数的乘和加操作看做一个 Mac,可以简单将其理解为 1Mac 相当于 2 FLOPs. 而 MMac 相当于 100 万次的 Mac 操作,GMac 相当于 10 亿次的 Mac 操作。

输入长度增加了一倍,计算量差不多也增加了一倍,说明自注意力机制的计算量和输入长度是正比的关系。为了能够高效的处理长本文序列,Longformer 对自注意力机制的计算方法进行了改进。

Paper:https://arxiv.org/pdf/2004.05150.pdf

另外一点思考:原始的自注意力机制计算时,每个 Token 都会关注到输入的所有其他 Token。是否有必要全部关注也是一个待思考的问题。

1. Sliding Window Attention

Longformer 中提出了局部的自注意力机制,即:注意力计算时,只关注窗口内的 token,而不是所有的 token。这个很明显能够降低自注意力计算的复杂度。这种自注意力的计算方式,Paper 中给出的名字叫做:sliding window attention,基于滑动窗口的注意力计算。

这里的 window size 如何设置呢?Paper 中给出一句话:

Depending on the application, it might be helpful to use different values of w for each layer to balance between efficiency and model representation capacity.

简单来说:就是不同的编码器层使用不同的 window size,例如:我们可以从一个较小的 window size 逐渐增加到一个较大的 window size,相当于逐渐增加感受范围。

2. Dilated Sliding Window Attention

为了在不增加计算量的情况下,表征 Token 时使用更大关注范围,Paper 中使用这种扩张的滑动窗口,如下图所示:

从图中可以快速理解到,原来扩张不过就是跳跃关注 token,这么一看确实范围变大了。Paper 针对这一情况,又给出了一段话:

In multi-headed attention, each attention head computes a different attention score. We found set- tings with different dilation configurations per head improves performance by allowing some heads without dilation to focus on local context, while others with dilation focus on longer context.

简单来讲,就是:我们不是有多头注意力吗?好,不同的 head 使用不同的扩张配置,比如:第一个头每隔 2 个 token 进行关注,第二个头每隔 8 个 token 进行关注,第三个头每隔 6 个 token 进行关注 … 以此类推,这样的方式能够提高模型的性能。

3. Global Attention + Sliding Window

到这里似乎就介绍完了 Longformer 的局部注意力机制。但是,突然发现了一个问题。我们原来在使用 Bert 时,会使用 [CLS] 来表征输入的整个序列,如果对 [CLS] 也使用局部注意力机制的计算方法,发现 [CLS] 可能并不能很好的表征输入的序列。所以,对于这样的特殊 token,我们仍然使用全局的注意力机制,也就是原始的关注所有的 token 的注意力计算方法。

这样的话:[CLS] 使用全局注意力,关注所有 token,其他的 token 除了关注窗口内的 token 之外,也得关注 [CLS] token。

we make this attention operation symmetric: that is, a token with a global attention attends to all tokens across the sequence, and all tokens in the sequence attend to it. 

注意:上面我们只是提到了 [CLS] 这个特殊的 token 使用全局注意力,当然,在我们的输入中可能还有一些其他位置也需要使用全局注意力,我们可以预先标记一些位置来使用全局注意力。比如:对于分类问题,只对 [CLS] 使用全局注意力即可,但对于 QA 问题,我们可能得需要对整个 question 的 token 使用全局注意力。

未经允许不得转载:一亩三分地 » Longformer
评论 (0)

3 + 6 =