Bert 模型对输入有 512 的长度限制,有时我们的输入会超过 512,此时就需要一些方法来解决,这里总结了一些方法。
- 修改模型的长度限制
- 对输入进行长度截断
- 通过滑动窗口重构输入
- 通过提取关键部分重构输入
- 使用支持长文本的模型
1. 修改模型的长度限制
如果我们从零开始训练一个输入长度超过 512 的模型,那么只需要在构建模型时,修改 BertConfig 配置就可以了,如下代码所示:
from transformers import BertModel from transformers import BertConfig if __name__ == '__main__': config = BertConfig() # 修改输入长度 config.max_position_embeddings = 1024 model = BertModel(config=config) print(config)
但是,如果在预训练模型的基础上修改文本输入长度,也可以通过下面的的方法进行修改。使用下面的方法进行微调时,损失值总是不会下降,原因尚未深入琢磨,这个问题在有答案是再补充。下面的方案权且是从编码角度实现。
from transformers import BertModel from transformers import BertConfig if __name__ == '__main__': config = BertConfig() config.max_position_embeddings = 1024 config.vocab_size = tokenizer.vocab_size # 读取 bert-base-chinese 的模型参数 state_dict = torch.load('bert-base-chinese/pytorch_model.bin') for key in state_dict.keys(): if key == 'bert.embeddings.position_embeddings.weight': if config.max_position_embeddings > 512: # 重新构建更大的位置编码矩阵 embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size) # 将预训练模型中 512 位置编码拷贝到新的位置编码矩阵中 embedding.weight.data[:512, :] = state_dict[key].data[:512, :] state_dict[key].data = embedding.weight.data # 使用修改后的参数来构建模型 model = BertForSequenceClassification.from_pretrained(None, state_dict=state_dict, config=config) model = model.to(device)
2. 对输入进行长度截断
这个方法思路比较简单,输入长度超过 512 时,我们可以将其截断到满足输入长度即可。此时,可以截取前 512(长度也要去考虑特殊 token) 长度,或后 512 长度。
3. 通过滑动窗口重构输入
我们可以通过滑动窗口的方式,把输入重构成多个满足长度限制的多条样本,然后送入模型进行训练。
4. 通过提取关键部分重构输入
https://proceedings.neurips.cc/paper/2020/file/96671501524948bc3937b4b30d0e57b9-Paper.pdf
通过一些方法识别长文本中的关键句子、然后排序、抽取、融合、形成新的文本再送入到模型中进行训练。
5. 使用支持长文本的模型
可以使用 XLNet 这种支持长文本输入的模型。