BERT 长度限制

Bert 模型对输入有 512 的长度限制,有时我们的输入会超过 512,此时就需要一些方法来解决,这里总结了一些方法。

  1. 修改模型的长度限制
  2. 对输入进行长度截断
  3. 通过滑动窗口重构输入
  4. 通过提取关键部分重构输入
  5. 使用支持长文本的模型

1. 修改模型的长度限制

如果我们从零开始训练一个输入长度超过 512 的模型,那么只需要在构建模型时,修改 BertConfig 配置就可以了,如下代码所示:

from transformers import BertModel
from transformers import BertConfig


if __name__ == '__main__':

    config = BertConfig()
    # 修改输入长度
    config.max_position_embeddings = 1024
    model = BertModel(config=config)
    print(config)

但是,如果在预训练模型的基础上修改文本输入长度,也可以通过下面的的方法进行修改。使用下面的方法进行微调时,损失值总是不会下降,原因尚未深入琢磨,这个问题在有答案是再补充。下面的方案权且是从编码角度实现。

from transformers import BertModel
from transformers import BertConfig


if __name__ == '__main__':

    config = BertConfig()
    config.max_position_embeddings = 1024
    config.vocab_size = tokenizer.vocab_size
    # 读取 bert-base-chinese 的模型参数
    state_dict = torch.load('bert-base-chinese/pytorch_model.bin')
    for key in state_dict.keys():
        if key == 'bert.embeddings.position_embeddings.weight':
            if config.max_position_embeddings > 512:
                # 重新构建更大的位置编码矩阵
                embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
                # 将预训练模型中 512 位置编码拷贝到新的位置编码矩阵中
                embedding.weight.data[:512, :] = state_dict[key].data[:512, :]
                state_dict[key].data = embedding.weight.data
    # 使用修改后的参数来构建模型
    model = BertForSequenceClassification.from_pretrained(None, state_dict=state_dict, config=config)
    model = model.to(device)

2. 对输入进行长度截断

这个方法思路比较简单,输入长度超过 512 时,我们可以将其截断到满足输入长度即可。此时,可以截取前 512(长度也要去考虑特殊 token) 长度,或后 512 长度。

3. 通过滑动窗口重构输入

我们可以通过滑动窗口的方式,把输入重构成多个满足长度限制的多条样本,然后送入模型进行训练。

4. 通过提取关键部分重构输入

https://proceedings.neurips.cc/paper/2020/file/96671501524948bc3937b4b30d0e57b9-Paper.pdf

通过一些方法识别长文本中的关键句子、然后排序、抽取、融合、形成新的文本再送入到模型中进行训练。

5. 使用支持长文本的模型

可以使用 XLNet 这种支持长文本输入的模型。

未经允许不得转载:一亩三分地 » BERT 长度限制