Next we write the training function, the evaluation function, and the prediction function.

1. Training Function

Since we want to feed the training data to the model in batches, and RNN, GRU, and LSTM layers expect aligned inputs, we can combine pad_sequence, pack_padded_sequence, and pad_packed_sequence to batch variable-length sequences. For encoding the input we use BertTokenizer; so that it splits the text at character granularity, the characters of each input sentence should be separated by spaces, which lets the tokenizer reliably produce one token per character.
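To show how these three functions fit together, here is a minimal, self-contained sketch. It is independent of the NER model below, and the tensor sizes are purely illustrative:

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# Three variable-length sequences of embedded characters (embedding dim 8),
# already sorted by length in descending order as pack_padded_sequence expects
seqs = [torch.randn(5, 8), torch.randn(3, 8), torch.randn(2, 8)]
lengths = [5, 3, 2]

# 1. Pad to a common length -> shape (max_len, batch, embed_dim)
padded = pad_sequence(seqs)
# 2. Pack so the LSTM skips the padded positions
packed = pack_padded_sequence(padded, lengths)
lstm = nn.LSTM(input_size=8, hidden_size=16)
packed_output, _ = lstm(packed)
# 3. Unpack back to a padded tensor for downstream layers
output, output_lengths = pad_packed_sequence(packed_output)
print(output.shape)  # torch.Size([5, 3, 16])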
import torch
import numpy as np
from torch import optim
from torch.nn.utils.rnn import pad_sequence
from datasets import load_from_disk
from transformers import BertTokenizer

# NER (the BiLSTM-CRF model class) is assumed to be defined earlier in this tutorial.
# Device used throughout; this tutorial assumes a CUDA GPU is available
device = torch.device('cuda')


def pad_batch_inputs(data, labels, tokenizer):
    # Returns the sentences and labels sorted by sentence length in
    # descending order, together with each sentence's length
    # Separate the batch into inputs and labels, and record each input's length
    data_inputs, data_length, data_labels = [], [], []
    for data_input, data_label in zip(data, labels):
        # Encode the input sentence
        data_input_encode = tokenizer.encode(data_input,
                                             return_tensors='pt',
                                             add_special_tokens=False)
        data_input_encode = data_input_encode.to(device)
        data_inputs.append(data_input_encode.squeeze())
        # Remove the spaces and compute the sentence length
        data_input = ''.join(data_input.split())
        data_length.append(len(data_input))
        # Convert the labels to a tensor
        data_labels.append(torch.tensor(data_label, device=device))
    # Sort the batch by length; the minus sign gives descending order
    sorted_index = np.argsort(-np.asarray(data_length))
    # Reorder inputs, labels, and lengths according to the sorted indices
    sorted_inputs, sorted_labels, sorted_length = [], [], []
    for index in sorted_index:
        sorted_inputs.append(data_inputs[index])
        sorted_labels.append(data_labels[index])
        sorted_length.append(data_length[index])
    # Pad the input tensors so they all have the same length
    pad_inputs = pad_sequence(sorted_inputs)
    return pad_inputs, sorted_labels, sorted_length
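For intuition, a hypothetical call might look like this (it assumes the vocab file used below exists; the two sentences and their label lists are made up):

tokenizer = BertTokenizer(vocab_file='data/bilstm_crf_vocab.txt')
batch = ['我 爱 北 京', '你 好']          # characters separated by spaces
labels = [[0, 0, 5, 6], [0, 0]]
pad_inputs, sorted_labels, sorted_length = pad_batch_inputs(batch, labels, tokenizer)
print(pad_inputs.shape)   # torch.Size([4, 2]): (max_len, batch), since
                          # pad_sequence defaults to batch_first=False
print(sorted_length)      # [4, 2]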
def train():
    # Load the dataset
    train_data = load_from_disk('data/bilstm_crf_data')['train']
    # Build the tokenizer
    tokenizer = BertTokenizer(vocab_file='data/bilstm_crf_vocab.txt')
    # Build the model
    model = NER(vocab_size=tokenizer.vocab_size, label_num=7).cuda(device)
    # Batch size
    batch_size = 16
    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=3e-5)
    # Number of training epochs
    num_epoch = 50

    # Train on one batch
    def start_train(data_inputs, data_labels, tokenizer):
        # Pad and align the batch
        pad_inputs, sorted_labels, sorted_length = \
            pad_batch_inputs(data_inputs, data_labels, tokenizer)
        # Compute the loss
        loss = model(pad_inputs, sorted_labels, sorted_length)
        # Zero the gradients
        optimizer.zero_grad()
        # Backpropagation
        loss.backward()
        # Update the parameters
        optimizer.step()
        # Accumulate the loss
        nonlocal total_loss
        total_loss += loss.item()

    for epoch in range(num_epoch):
        # Reset the loss accumulator
        total_loss = 0.0
        # Run one epoch over the training set
        train_data.map(start_train,
                       input_columns=['data_inputs', 'data_labels'],
                       batched=True,
                       batch_size=batch_size,
                       fn_kwargs={'tokenizer': tokenizer},
                       desc='epoch: %d' % (epoch + 1))
        # Print the epoch loss
        print('epoch: %d loss: %.3f' % (epoch + 1, total_loss))
        # Save the model checkpoint for this epoch
        model.save_model('data/BiLSTM-CRF-%d.bin' % (epoch + 1))


if __name__ == '__main__':
    train()
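The save_model method belongs to the NER class defined earlier. Since evaluate() below restores checkpoints via torch.load and expects the keys 'init' and 'state', save_model presumably writes something along these lines (a sketch only; the attribute names vocab_size and label_num are assumptions):

def save_model(self, save_path):
    # 'init' holds the constructor arguments and 'state' the learned weights,
    # matching how evaluate() rebuilds the model: NER(**model_param['init'])
    torch.save({'init': {'vocab_size': self.vocab_size, 'label_num': self.label_num},
                'state': self.state_dict()}, save_path)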
The loss printed during training (the per-epoch sum of the batch losses) is shown below:
epoch: 1 loss: 646530.528
epoch: 2 loss: 322091.170
epoch: 3 loss: 231602.128
epoch: 4 loss: 182503.388
epoch: 5 loss: 149661.446
epoch: 6 loss: 125265.489
epoch: 7 loss: 106000.216
epoch: 8 loss: 90156.823
epoch: 9 loss: 76753.042
epoch: 10 loss: 65195.731
epoch: 11 loss: 55057.508
epoch: 12 loss: 46142.997
epoch: 13 loss: 38262.882
epoch: 14 loss: 31367.848
epoch: 15 loss: 25804.164
epoch: 16 loss: 23088.218
epoch: 17 loss: 18572.793
epoch: 18 loss: 15087.369
epoch: 19 loss: 13349.071
epoch: 20 loss: 11036.982
epoch: 21 loss: 9877.440
epoch: 22 loss: 8125.003
epoch: 23 loss: 6185.219
epoch: 24 loss: 5038.318
epoch: 25 loss: 4788.188
epoch: 26 loss: 3986.904
epoch: 27 loss: 3301.328
epoch: 28 loss: 2872.219
epoch: 29 loss: 2691.445
epoch: 30 loss: 3137.029
epoch: 31 loss: 1619.925
epoch: 32 loss: 2354.775
epoch: 33 loss: 2398.700
epoch: 34 loss: 1662.177
epoch: 35 loss: 1534.003
epoch: 36 loss: 1720.913
epoch: 37 loss: 1415.495
epoch: 38 loss: 1402.394
epoch: 39 loss: 1028.855
epoch: 40 loss: 1089.962
epoch: 41 loss: 1016.576
epoch: 42 loss: 1140.854
epoch: 43 loss: 1103.749
epoch: 44 loss: 735.662
epoch: 45 loss: 926.108
epoch: 46 loss: 1021.045
epoch: 47 loss: 1039.586
epoch: 48 loss: 608.818
epoch: 49 loss: 846.596
epoch: 50 loss: 685.285
All of the resulting model files are available at: https://www.aliyundrive.com/s/HV1EoeYvKQ3 (extraction code: su42)
2. Evaluation Function

Here we evaluate the data/BiLSTM-CRF-50.bin model, though any other checkpoint can be evaluated the same way. The main evaluation steps are:

- collect the gold entity names of each class in the test set, storing each class in its own list;
- run the test set through the model and collect the predicted entity names per class;
- compute precision and recall for each entity class, plus the overall accuracy on the whole test set (a small worked example of these ratios follows this list).
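The counts below are made-up numbers, purely to illustrate how the ratios are formed:

# Hypothetical counts for one entity class:
true_entities_num = 100                  # gold entities of this class
pred_correct, pred_incorrect = 80, 25    # model predictions, split by correctness

recall = pred_correct / true_entities_num                    # 80 / 100 = 0.800
precision = pred_correct / (pred_correct + pred_incorrect)   # 80 / 105 ≈ 0.762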
def evaluate():
    # Load the test data
    valid_data = load_from_disk('data/bilstm_crf_data')['valid']
    # 1. Count the gold entities of each class
    # Collect the entities appearing in the test set
    total_entities = {'ORG': [], 'PER': [], 'LOC': []}

    def calculate_handler(data_inputs, data_labels):
        # Turn data_inputs back into a sentence without spaces
        data_inputs = ''.join(data_inputs.split())
        # Extract the entities in the sentence
        extract_entities = extract_decode(data_labels, data_inputs)
        # Accumulate the entities of each class
        nonlocal total_entities
        for key, value in extract_entities.items():
            total_entities[key].extend(value)

    # Collect the gold entities of each class
    valid_data.map(calculate_handler, input_columns=['data_inputs', 'data_labels'])
    print(total_entities)

    # 2. Collect the entities of each class predicted by the model
    model_param = torch.load('data/BiLSTM-CRF-50.bin')
    model = NER(**model_param['init']).cuda(device)
    model.load_state_dict(model_param['state'])
    # Build the tokenizer
    tokenizer = BertTokenizer(vocab_file='data/bilstm_crf_vocab.txt')
    model_entities = {'ORG': [], 'PER': [], 'LOC': []}

    def start_evaluate(data_inputs):
        # Encode the input text
        model_inputs = tokenizer.encode(data_inputs, add_special_tokens=False, return_tensors='pt')[0]
        model_inputs = model_inputs.to(device)
        # Run the text through the model
        with torch.no_grad():
            label_list = model.predict(model_inputs)
        # Recover the sentence without spaces
        text = ''.join(data_inputs.split())
        # Extract the entity names from the predictions
        extract_entities = extract_decode(label_list, text)
        nonlocal model_entities
        for key, value in extract_entities.items():
            model_entities[key].extend(value)

    # Collect the predicted entities of each class
    valid_data.map(start_evaluate, input_columns=['data_inputs'], batched=False)
    print(model_entities)

    # 3. Compute recall and precision for each class
    total_pred_correct = 0
    total_true_correct = 0
    for key in total_entities.keys():
        # Gold and predicted entity lists for the current class
        true_entities = total_entities[key]
        true_entities_num = len(true_entities)
        pred_entities = model_entities[key]
        # Split the predictions: pred_correct counts correct predictions,
        # pred_incorrect counts incorrect ones
        pred_correct, pred_incorrect = 0, 0
        for pred_entity in pred_entities:
            if pred_entity in true_entities:
                pred_correct += 1
                continue
            pred_incorrect += 1
        # Total number of entities the model predicted for this class
        model_pred_key_num = pred_correct + pred_incorrect
        # Accumulate the correctly predicted entities
        total_pred_correct += pred_correct
        # Accumulate the gold entities
        total_true_correct += true_entities_num
        # Print the per-class metrics
        print(key, 'recall: %.3f' % (pred_correct / true_entities_num))
        print(key, 'precision: %.3f' % (pred_correct / model_pred_key_num))
        print('-' * 50)

    print('accuracy: %.3f' % (total_pred_correct / total_true_correct))
def extract_decode(label_list, text):
    """
    :param label_list: one-dimensional list of label indices output by the model
    :param text: the sentence fed into the model
    :return: the extracted entity names
    """
    labels = ['O', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC']
    label_to_index = {label: index for index, label in enumerate(labels)}
    B_ORG, I_ORG = label_to_index['B-ORG'], label_to_index['I-ORG']
    B_PER, I_PER = label_to_index['B-PER'], label_to_index['I-PER']
    B_LOC, I_LOC = label_to_index['B-LOC'], label_to_index['I-LOC']

    # Extract the entity spanned by a run of consecutive labels
    def extract_word(start_index, next_label):
        # index tracks the position of the last character examined
        index, entity = start_index + 1, [text[start_index]]
        for index in range(start_index + 1, len(label_list)):
            if label_list[index] != next_label:
                break
            entity.append(text[index])
        return index, ''.join(entity)

    # Store the extracted named entities
    extract_entities, index = {'ORG': [], 'PER': [], 'LOC': []}, 0
    # Map each B- label to its continuation (I-) label
    next_label = {B_ORG: I_ORG, B_PER: I_PER, B_LOC: I_LOC}
    # Map each B- label to its entity class
    word_class = {B_ORG: 'ORG', B_PER: 'PER', B_LOC: 'LOC'}
    while index < len(label_list):
        # Label at the current position
        label = label_list[index]
        if label in next_label.keys():
            # Pass the current position and the matching continuation label to extract_word
            index, word = extract_word(index, next_label[label])
            extract_entities[word_class[label]].append(word)
            continue
        index += 1
    return extract_entities


if __name__ == '__main__':
    evaluate()
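To see extract_decode on its own, here is a toy call (the label indices follow the labels list above; the six-character sentence is made up):

# 'B-PER I-PER O O B-LOC I-LOC' over the string '张三住在北京' ("Zhang San lives in Beijing")
label_list = [3, 4, 0, 0, 5, 6]
print(extract_decode(label_list, '张三住在北京'))
# {'ORG': [], 'PER': ['张三'], 'LOC': ['北京']}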
The program output:
ORG recall: 0.762
ORG precision: 0.655
--------------------------------------------------
PER recall: 0.824
PER precision: 0.716
--------------------------------------------------
LOC recall: 0.846
LOC precision: 0.778
--------------------------------------------------
accuracy: 0.821
3. Prediction Function

The prediction function takes a sentence and extracts the entities it contains. The steps are:

- separate the characters of the input sentence with spaces;
- encode it with BertTokenizer;
- run the model to predict the label sequence (the CRF layer performs Viterbi decoding internally; a sketch of that recursion follows this list);
- parse the entity names out of the label sequence with extract_decode.
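For reference, the core of Viterbi decoding can be sketched as follows. This is a minimal illustration, not the CRF layer's actual code; it assumes emission scores of shape (seq_len, num_tags) and a transition matrix of shape (num_tags, num_tags), and omits the start/end transitions a full CRF adds:

import torch

def viterbi_decode(emissions, transitions):
    # emissions: (seq_len, num_tags); transitions[i][j]: score of moving from tag i to tag j
    seq_len, num_tags = emissions.shape
    score = emissions[0]          # best score of a path ending in each tag at step 0
    history = []
    for t in range(1, seq_len):
        # score[i] + transitions[i][j] + emissions[t][j], maximized over previous tag i
        next_score = score.unsqueeze(1) + transitions + emissions[t]
        score, best_prev = next_score.max(dim=0)
        history.append(best_prev)
    # Backtrack from the best final tag
    best_tag = score.argmax().item()
    best_path = [best_tag]
    for best_prev in reversed(history):
        best_tag = best_prev[best_tag].item()
        best_path.append(best_tag)
    return list(reversed(best_path))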
def entity_extract(text):
    # Build the tokenizer
    tokenizer = BertTokenizer(vocab_file='data/bilstm_crf_vocab.txt')
    # Initialize the model
    model_param = torch.load('data/BiLSTM-CRF-48.bin')
    model = NER(**model_param['init']).cuda(device)
    model.load_state_dict(model_param['state'])
    # Split the text into characters and join them with spaces so the Bert
    # tokenizer splits it accurately at character granularity
    input_text = ' '.join(list(text))
    model_inputs = tokenizer.encode(input_text, add_special_tokens=False, return_tensors='pt')[0]
    model_inputs = model_inputs.to(device)
    with torch.no_grad():
        outputs = model.predict(model_inputs)
    return extract_decode(outputs, ''.join(input_text.split()))


if __name__ == '__main__':
    # "I want to thank the Los Angeles civic forum, the Asia Society Southern
    # California Center, the National Committee on U.S.-China Relations, the
    # US-China Friendship Association West Coast chapter, and other friendly
    # groups for their warm hospitality."
    text = '我要感谢洛杉矶市民议政论坛、亚洲协会南加中心、美中关系全国委员会、美中友协美西分会等友好团体的盛情款待。'
    result = entity_extract(text)
    print(result)
The program output:
{'ORG': ['亚洲协会南加中心', '美中关系全国委员会', '美中友协美西分会'], 'PER': [], 'LOC': ['洛杉矶']}
