Next, we write the training function, the evaluation function, and the prediction function.

1. Training Function

Since we want to feed the training data to the network in batches, and the sentences in a batch have different lengths, we can combine pad_sequence, pack_padded_sequence, and pad_packed_sequence to handle batched variable-length input to an RNN, GRU, or LSTM. For encoding the inputs we use BertTokenizer; to make sure it splits the text exactly at character granularity, it is best to separate the characters of each input sentence with spaces, so that the tokenizer produces one token per character.
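The pad_batch_inputs helper below only pads the batch; the packing itself happens inside the model's forward pass, presumably defined in an earlier section. For reference, here is a minimal, self-contained sketch of the full pad/pack/unpack pattern with a bidirectional LSTM, assuming the inputs are already sorted by descending length (the shapes and layer sizes are illustrative only, not taken from the NER model):

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# Three variable-length "sentences" with embedding dimension 8, sorted longest first
sequences = [torch.randn(5, 8), torch.randn(3, 8), torch.randn(2, 8)]
lengths = [5, 3, 2]

# Pad into a (max_len, batch, dim) tensor (pad_sequence defaults to batch_first=False)
padded = pad_sequence(sequences)                 # shape: (5, 3, 8)

# Pack so the LSTM skips the padded positions
packed = pack_padded_sequence(padded, lengths)

lstm = nn.LSTM(input_size=8, hidden_size=16, bidirectional=True)
packed_output, _ = lstm(packed)

# Unpack back into a padded tensor plus the true lengths
output, output_lengths = pad_packed_sequence(packed_output)
print(output.shape)                              # torch.Size([5, 3, 32])
print(output_lengths)                            # tensor([5, 3, 2])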
import numpy as np
import torch
from torch import optim
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizer
from datasets import load_from_disk

# NER (the BiLSTM-CRF model class) and device are assumed to be defined
# in the earlier model section of this article


def pad_batch_inputs(data, labels, tokenizer):
    # Returns the sentences and labels sorted by sentence length in
    # descending order, together with the sentence lengths.
    # Separate the batch inputs from the labels and compute each input's length
    data_inputs, data_length, data_labels = [], [], []
    for data_input, data_label in zip(data, labels):
        # Encode the input sentence
        data_input_encode = tokenizer.encode(data_input,
                                             return_tensors='pt',
                                             add_special_tokens=False)
        data_input_encode = data_input_encode.to(device)
        data_inputs.append(data_input_encode.squeeze())
        # Strip the separating spaces and compute the sentence length
        data_input = ''.join(data_input.split())
        data_length.append(len(data_input))
        # Convert the labels into a tensor
        data_labels.append(torch.tensor(data_label, device=device))

    # Sort the batch by length in descending order (the minus sign means descending)
    sorted_index = np.argsort(-np.asarray(data_length))

    # Reorder the inputs, labels, and lengths according to the sorted indexes
    sorted_inputs, sorted_labels, sorted_length = [], [], []
    for index in sorted_index:
        sorted_inputs.append(data_inputs[index])
        sorted_labels.append(data_labels[index])
        sorted_length.append(data_length[index])

    # Pad the tensors so they all have the same length
    pad_inputs = pad_sequence(sorted_inputs)

    return pad_inputs, sorted_labels, sorted_length


def train():
    # Load the dataset
    train_data = load_from_disk('data/bilstm_crf_data')['train']
    # Build the tokenizer
    tokenizer = BertTokenizer(vocab_file='data/bilstm_crf_vocab.txt')
    # Build the model
    model = NER(vocab_size=tokenizer.vocab_size, label_num=7).cuda(device)
    # Batch size
    batch_size = 16
    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=3e-5)
    # Number of training epochs
    num_epoch = 50

    # One training step over a batch
    def start_train(data_inputs, data_labels, tokenizer):
        # Pad and align the batch data
        pad_inputs, sorted_labels, sorted_length = \
            pad_batch_inputs(data_inputs, data_labels, tokenizer)
        # Compute the loss
        loss = model(pad_inputs, sorted_labels, sorted_length)
        # Zero the gradients
        optimizer.zero_grad()
        # Backpropagate
        loss.backward()
        # Update the parameters
        optimizer.step()
        # Accumulate the loss
        nonlocal total_loss
        total_loss += loss.item()

    for epoch in range(num_epoch):
        # Accumulated loss for this epoch
        total_loss = 0.0
        # Run training over the whole dataset
        train_data.map(start_train,
                       input_columns=['data_inputs', 'data_labels'],
                       batched=True,
                       batch_size=batch_size,
                       fn_kwargs={'tokenizer': tokenizer},
                       desc='epoch: %d' % (epoch + 1))
        # Print the loss
        print('epoch: %d loss: %.3f' % (epoch + 1, total_loss))
        # Save the model
        model.save_model('data/BiLSTM-CRF-%d.bin' % (epoch + 1))


if __name__ == '__main__':
    train()
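The evaluation code in the next section loads checkpoints via torch.load(...)['init'] and ['state'], so save_model presumably stores the model's constructor arguments alongside its weights. A minimal sketch of what such a method could look like, assuming the NER class records its constructor arguments (an illustrative reconstruction, not the article's actual implementation):

import torch
from torch import nn

class NER(nn.Module):
    def __init__(self, vocab_size, label_num):
        super().__init__()
        # ... the embedding, BiLSTM, and CRF layers would be defined here ...
        self.vocab_size = vocab_size
        self.label_num = label_num

    def save_model(self, save_path):
        # Store the constructor arguments next to the weights, matching the
        # torch.load(...)['init'] / ['state'] usage in evaluate() below
        save_info = {
            'init': {'vocab_size': self.vocab_size, 'label_num': self.label_num},
            'state': self.state_dict(),
        }
        torch.save(save_info, save_path)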
The loss printed during training is as follows:
epoch: 1 loss: 646530.528
epoch: 2 loss: 322091.170
epoch: 3 loss: 231602.128
epoch: 4 loss: 182503.388
epoch: 5 loss: 149661.446
epoch: 6 loss: 125265.489
epoch: 7 loss: 106000.216
epoch: 8 loss: 90156.823
epoch: 9 loss: 76753.042
epoch: 10 loss: 65195.731
epoch: 11 loss: 55057.508
epoch: 12 loss: 46142.997
epoch: 13 loss: 38262.882
epoch: 14 loss: 31367.848
epoch: 15 loss: 25804.164
epoch: 16 loss: 23088.218
epoch: 17 loss: 18572.793
epoch: 18 loss: 15087.369
epoch: 19 loss: 13349.071
epoch: 20 loss: 11036.982
epoch: 21 loss: 9877.440
epoch: 22 loss: 8125.003
epoch: 23 loss: 6185.219
epoch: 24 loss: 5038.318
epoch: 25 loss: 4788.188
epoch: 26 loss: 3986.904
epoch: 27 loss: 3301.328
epoch: 28 loss: 2872.219
epoch: 29 loss: 2691.445
epoch: 30 loss: 3137.029
epoch: 31 loss: 1619.925
epoch: 32 loss: 2354.775
epoch: 33 loss: 2398.700
epoch: 34 loss: 1662.177
epoch: 35 loss: 1534.003
epoch: 36 loss: 1720.913
epoch: 37 loss: 1415.495
epoch: 38 loss: 1402.394
epoch: 39 loss: 1028.855
epoch: 40 loss: 1089.962
epoch: 41 loss: 1016.576
epoch: 42 loss: 1140.854
epoch: 43 loss: 1103.749
epoch: 44 loss: 735.662
epoch: 45 loss: 926.108
epoch: 46 loss: 1021.045
epoch: 47 loss: 1039.586
epoch: 48 loss: 608.818
epoch: 49 loss: 846.596
epoch: 50 loss: 685.285
All of the resulting model checkpoints can be downloaded from: https://www.aliyundrive.com/s/HV1EoeYvKQ3 (extraction code: su42)
2. Evaluation Function
Here we evaluate the data/BiLSTM-CRF-50.bin model; you can of course evaluate any other checkpoint in the same way. The main evaluation steps are:
- Collect the entity names of each category that appear in the test set, storing each category's names in its own list;
- Feed the test set to the model to obtain predictions, and store the predicted entity names by category as well;
- Compute precision and recall for each entity category, plus an overall accuracy over the whole test set (a minimal sketch of these metrics follows this list).
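For reference, here is a minimal sketch of per-category recall and precision computed over entity name lists, using made-up toy data; note that the evaluate() code below uses a slightly different precision denominator (true entity count plus incorrect predictions), which equals len(pred_entities) only when every true entity is recovered:

# Toy example of per-category recall and precision over entity name lists
true_entities = ['北京', '上海', '广州']   # entities labeled in the test set
pred_entities = ['北京', '上海', '深圳']   # entities the model predicted

pred_correct = sum(1 for e in pred_entities if e in true_entities)   # 2
recall = pred_correct / len(true_entities)      # 2 / 3
precision = pred_correct / len(pred_entities)   # 2 / 3
print('recall: %.3f, precision: %.3f' % (recall, precision))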
import torch
from transformers import BertTokenizer
from datasets import load_from_disk

# NER and device are defined as in the training script


def evaluate():
    # Load the test data
    valid_data = load_from_disk('data/bilstm_crf_data')['valid']

    # 1. Count the true entities of each category in the test set
    total_entities = {'ORG': [], 'PER': [], 'LOC': []}

    def calculate_handler(data_inputs, data_labels):
        # Join data_inputs back into a sentence without spaces
        data_inputs = ''.join(data_inputs.split())
        # Extract the entities in the sentence
        extract_entities = extract_decode(data_labels, data_inputs)
        # Accumulate the entities of each category
        nonlocal total_entities
        for key, value in extract_entities.items():
            total_entities[key].extend(value)

    # Count the true entities of each category
    valid_data.map(calculate_handler, input_columns=['data_inputs', 'data_labels'])
    print(total_entities)

    # 2. Count the entities of each category predicted by the model
    model_param = torch.load('data/BiLSTM-CRF-50.bin')
    model = NER(**model_param['init']).cuda(device)
    model.load_state_dict(model_param['state'])
    # Build the tokenizer
    tokenizer = BertTokenizer(vocab_file='data/bilstm_crf_vocab.txt')
    model_entities = {'ORG': [], 'PER': [], 'LOC': []}

    def start_evaluate(data_inputs):
        # Encode the input text
        model_inputs = tokenizer.encode(data_inputs,
                                        add_special_tokens=False,
                                        return_tensors='pt')[0]
        model_inputs = model_inputs.to(device)
        # Run the model on the text
        with torch.no_grad():
            label_list = model.predict(model_inputs)
        # Join the input back into a sentence without spaces
        text = ''.join(data_inputs.split())
        # Extract the entity names from the predictions
        extract_entities = extract_decode(label_list, text)
        nonlocal model_entities
        for key, value in extract_entities.items():
            model_entities[key].extend(value)

    # Count the predicted entities of each category
    valid_data.map(start_evaluate, input_columns=['data_inputs'], batched=False)
    print(model_entities)

    # 3. Compute recall and precision for each category
    total_pred_correct = 0
    total_true_correct = 0
    for key in total_entities.keys():
        # Get the true and predicted entity lists for the current category
        true_entities = total_entities[key]
        true_entities_num = len(true_entities)
        pred_entities = model_entities[key]

        # Split the predictions: pred_correct counts correct predictions,
        # pred_incorrect counts incorrect ones
        pred_correct, pred_incorrect = 0, 0
        for pred_entity in pred_entities:
            if pred_entity in true_entities:
                pred_correct += 1
                continue
            pred_incorrect += 1

        # Denominator used here for precision: the number of true entities
        # plus the number of incorrect predictions
        model_pred_key_num = true_entities_num + pred_incorrect

        # Accumulate the number of correctly predicted entities
        total_pred_correct += pred_correct
        # Accumulate the number of true entities
        total_true_correct += true_entities_num

        # Print the per-category metrics
        print(key, 'recall: %.3f' % (pred_correct / true_entities_num))
        print(key, 'precision: %.3f' % (pred_correct / model_pred_key_num))
        print('-' * 50)

    print('accuracy: %.3f' % (total_pred_correct / total_true_correct))


def extract_decode(label_list, text):
    """
    :param label_list: flat list of label ids output by the model
    :param text: the sentence fed to the model
    :return: the extracted entity names
    """
    labels = ['O', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC']
    label_to_index = {label: index for index, label in enumerate(labels)}
    B_ORG, I_ORG = label_to_index['B-ORG'], label_to_index['I-ORG']
    B_PER, I_PER = label_to_index['B-PER'], label_to_index['I-PER']
    B_LOC, I_LOC = label_to_index['B-LOC'], label_to_index['I-LOC']

    # Extract the entity spanned by a run of consecutive labels
    def extract_word(start_index, next_label):
        # index marks the position of the last examined label
        index, entity = start_index + 1, [text[start_index]]
        for index in range(start_index + 1, len(label_list)):
            if label_list[index] != next_label:
                break
            entity.append(text[index])
        return index, ''.join(entity)

    # Store the extracted named entities
    extract_entities, index = {'ORG': [], 'PER': [], 'LOC': []}, 0
    # Map each B- label to its continuation I- label
    next_label = {B_ORG: I_ORG, B_PER: I_PER, B_LOC: I_LOC}
    # Map each B- label to its entity category
    word_class = {B_ORG: 'ORG', B_PER: 'PER', B_LOC: 'LOC'}

    while index < len(label_list):
        # Get the label at the current position
        label = label_list[index]
        if label in next_label.keys():
            # Hand the current position and its continuation label to extract_word
            index, word = extract_word(index, next_label[label])
            extract_entities[word_class[label]].append(word)
            continue
        index += 1

    return extract_entities


if __name__ == '__main__':
    evaluate()
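As a quick sanity check of extract_decode, here is a small hand-built example (the label ids follow the labels list above; the sentence and tags are made up):

# Label ids: O=0, B-ORG=1, I-ORG=2, B-PER=3, I-PER=4, B-LOC=5, I-LOC=6
text = '王小明在北京工作'
label_list = [3, 4, 4, 0, 5, 6, 0, 0]   # 王小明 -> PER, 北京 -> LOC
print(extract_decode(label_list, text))
# {'ORG': [], 'PER': ['王小明'], 'LOC': ['北京']}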
The program outputs:
ORG recall: 0.762
ORG precision: 0.655
--------------------------------------------------
PER recall: 0.824
PER precision: 0.716
--------------------------------------------------
LOC recall: 0.846
LOC precision: 0.778
--------------------------------------------------
accuracy: 0.821
3. Prediction Function
The prediction function takes a sentence as input and extracts the entities it contains. The steps are:
- Insert spaces between the characters of the input sentence;
- Encode it with BertTokenizer;
- Run the model to predict the label sequence (the CRF layer decodes it with the Viterbi algorithm; see the sketch after this list);
- Parse the entity names out of the label sequence with extract_decode.
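model.predict relies on the CRF layer's Viterbi decoding, whose implementation lives inside the NER model from the earlier section. As an illustration of the algorithm itself, here is a minimal NumPy sketch of Viterbi decoding over emission and transition log-scores (the function and its toy inputs are hypothetical, not the article's CRF code):

import numpy as np

def viterbi_decode(emissions, transitions):
    # emissions: (seq_len, num_labels) log-score of each label at each position
    # transitions: (num_labels, num_labels) log-score of moving from label i to label j
    seq_len, num_labels = emissions.shape
    # score[t, j]: best score of any path ending in label j at step t
    score = np.zeros((seq_len, num_labels))
    backpointers = np.zeros((seq_len, num_labels), dtype=int)
    score[0] = emissions[0]
    for t in range(1, seq_len):
        # candidate[i, j]: score of extending a path ending in i with label j
        candidate = score[t - 1][:, None] + transitions + emissions[t][None, :]
        backpointers[t] = candidate.argmax(axis=0)
        score[t] = candidate.max(axis=0)
    # Walk the backpointers from the best final label
    best_path = [int(score[-1].argmax())]
    for t in range(seq_len - 1, 0, -1):
        best_path.append(int(backpointers[t][best_path[-1]]))
    return best_path[::-1]

# Toy check: 3 positions, 2 labels, no transition preference
emissions = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
transitions = np.zeros((2, 2))
print(viterbi_decode(emissions, transitions))   # [0, 1, 0]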
import torch
from transformers import BertTokenizer

# NER, device, and extract_decode are defined as in the evaluation script


def entity_extract(text):
    # Build the tokenizer
    tokenizer = BertTokenizer(vocab_file='data/bilstm_crf_vocab.txt')
    # Initialize the model
    model_param = torch.load('data/BiLSTM-CRF-48.bin')
    model = NER(**model_param['init']).cuda(device)
    model.load_state_dict(model_param['state'])

    # Split the text into single characters separated by spaces so that
    # the BERT tokenizer splits it exactly by character
    input_text = ' '.join(list(text))
    model_inputs = tokenizer.encode(input_text,
                                    add_special_tokens=False,
                                    return_tensors='pt')[0]
    model_inputs = model_inputs.to(device)

    with torch.no_grad():
        outputs = model.predict(model_inputs)

    return extract_decode(outputs, ''.join(input_text.split()))


if __name__ == '__main__':
    text = '我要感谢洛杉矶市民议政论坛、亚洲协会南加中心、美中关系全国委员会、美中友协美西分会等友好团体的盛情款待。'
    result = entity_extract(text)
    print(result)
The program outputs:
{'ORG': ['亚洲协会南加中心', '美中关系全国委员会', '美中友协美西分会'], 'PER': [], 'LOC': ['洛杉矶']}