BERT + CRF is currently one of the most common architectures for solving the NER problem. BERT's multi-head self-attention lets the model capture the relationships between tokens and produce a first-stage prediction of BIO labels. At this stage, however, the model is never asked to learn the constraints that hold between the labels of a sequence, for example that I_PERSON cannot be the first label of a sequence, or that I_PERSON cannot be followed by I_LOC. Adding this constraint learning makes the NER results more robust, and the CRF is exactly the technique that learns these label transition constraints from the training samples.
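To make these constraints concrete, below is a minimal sketch of how the legal (from_tag_id, to_tag_id) pairs of a BIO scheme can be enumerated. This is not fastNLP's own allowed_transitions helper, and the 7-tag label set is an assumption made purely for illustration; the returned list has the shape the allowed_transitions parameter of the CRF below expects, including the virtual start/end states at indices num_tags and num_tags + 1.

# A sketch of enumerating legal BIO transitions; NOT fastNLP's own
# allowed_transitions() helper. The 7-tag label set is an assumed example.
TAGS = ['O', 'B_PERSON', 'I_PERSON', 'B_LOC', 'I_LOC', 'B_ORG', 'I_ORG']

def bio_allowed_transitions(tags):
    num_tags = len(tags)
    start, end = num_tags, num_tags + 1  # virtual start/end states used by the CRF below
    allowed = []
    for i, from_tag in enumerate(tags):
        for j, to_tag in enumerate(tags):
            # I_X may only follow B_X or I_X of the same entity type
            if to_tag.startswith('I_') and from_tag[1:] != to_tag[1:]:
                continue
            allowed.append((i, j))
    # a sequence may start with O or any B_*, but never with I_*
    allowed += [(start, j) for j, tag in enumerate(tags) if not tag.startswith('I_')]
    # any tag may legally end a sequence
    allowed += [(i, end) for i in range(num_tags)]
    return allowed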
First, here is the implementation of the CRF itself. I walked through the ideas behind this code in a previous article; the implementation given here comes from the fastNLP library (thanks to its authors).
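Before reading the code, it helps to pin down what its two core methods compute. Writing E for the emission scores (the feats argument) and T for the learned transition matrix trans_m, the value returned by forward is the standard negative log-likelihood of a linear-chain CRF:

\mathcal{L}(x, y)
  = \underbrace{\log \sum_{y'} \exp s(x, y')}_{\texttt{all\_path\_score}}
  - \underbrace{s(x, y)}_{\texttt{gold\_path\_score}},
\qquad
s(x, y) = \sum_{t=1}^{L} E_{t, y_t} + \sum_{t=2}^{L} T_{y_{t-1}, y_t}

_normalizer_likelihood evaluates the log-partition term with the forward algorithm, and _gold_score evaluates s(x, y) for the gold tag sequence. The code is as follows: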
from typing import List, Optional, Tuple

import torch
import torch.nn as nn


class ConditionalRandomField(nn.Module):
    r"""
    Conditional Random Field. Provides :meth:`forward` for **training** and
    :meth:`viterbi_decode` for **inference**.

    :param num_tags: number of tags
    :param include_start_end_trans: whether to score each tag as the start or the end of a sequence.
    :param allowed_transitions: each inner ``Tuple[from_tag_id(int), to_tag_id(int)]`` is treated as an
        allowed transition; every transition not listed is forbidden. Such a list can be produced with
        the :func:`allowed_transitions` function. If ``None``, every transition is legal.
    """

    def __init__(self, num_tags: int, include_start_end_trans: bool = False,
                 allowed_transitions: Optional[List[Tuple[int, int]]] = None):
        super(ConditionalRandomField, self).__init__()

        self.include_start_end_trans = include_start_end_trans
        self.num_tags = num_tags

        # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score
        self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags))
        if self.include_start_end_trans:
            self.start_scores = nn.Parameter(torch.randn(num_tags))
            self.end_scores = nn.Parameter(torch.randn(num_tags))

        if allowed_transitions is None:
            constrain = torch.zeros(num_tags + 2, num_tags + 2)
        else:
            constrain = torch.full((num_tags + 2, num_tags + 2), fill_value=-10000.0, dtype=torch.float)
            has_start = False
            has_end = False
            for from_tag_id, to_tag_id in allowed_transitions:
                constrain[from_tag_id, to_tag_id] = 0
                if from_tag_id == num_tags:
                    has_start = True
                if to_tag_id == num_tags + 1:
                    has_end = True
            if not has_start:
                constrain[num_tags, :].fill_(0)
            if not has_end:
                constrain[:, num_tags + 1].fill_(0)
        self._constrain = nn.Parameter(constrain, requires_grad=False)

    def _normalizer_likelihood(self, logits, mask):
        r"""
        Computes the (batch_size,) denominator term for the log-likelihood, which is the
        sum of the likelihoods across all possible state sequences.

        :param logits: FloatTensor, ``[max_len, batch_size, num_tags]``
        :param mask: ByteTensor, ``[max_len, batch_size]``
        :return: FloatTensor, ``[batch_size,]``
        """
        seq_len, batch_size, n_tags = logits.size()
        alpha = logits[0]
        if self.include_start_end_trans:
            alpha = alpha + self.start_scores.view(1, -1)

        flip_mask = mask.eq(False)

        for i in range(1, seq_len):
            emit_score = logits[i].view(batch_size, 1, n_tags)
            trans_score = self.trans_m.view(1, n_tags, n_tags)
            tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
            alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
                alpha.masked_fill(mask[i].eq(True).view(batch_size, 1), 0)

        if self.include_start_end_trans:
            alpha = alpha + self.end_scores.view(1, -1)

        return torch.logsumexp(alpha, 1)

    def _gold_score(self, logits, tags, mask):
        r"""
        Compute the score for the gold path.

        :param logits: FloatTensor, ``[max_len, batch_size, num_tags]``
        :param tags: LongTensor, ``[max_len, batch_size]``
        :param mask: ByteTensor, ``[max_len, batch_size]``
        :return: FloatTensor, ``[batch_size,]``
        """
        seq_len, batch_size, _ = logits.size()
        batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
        seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

        # trans_score [L-1, B]
        mask = mask.eq(True)
        flip_mask = mask.eq(False)
        trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
        # emit_score [L, B]
        emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0)
        # score [L-1, B]
        score = trans_score + emit_score[:seq_len - 1, :]
        score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0)
        if self.include_start_end_trans:
            st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
            last_idx = mask.long().sum(0) - 1
            ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
            score = score + st_scores + ed_scores
        # return [B,]
        return score

    def forward(self, feats: "torch.FloatTensor", tags: "torch.LongTensor",
                mask: "torch.ByteTensor") -> "torch.FloatTensor":
        r"""
        Compute the forward loss of the ``CRF``. Returns a :class:`torch.FloatTensor` of shape
        ``[batch_size,]``; you may need :func:`mean` to reduce it to a scalar loss.

        :param feats: feature matrix of shape ``[batch_size, max_len, num_tags]``
        :param tags: tag matrix of shape ``[batch_size, max_len]``
        :param mask: shape ``[batch_size, max_len]``; positions that are **0** are treated as padding.
        :return: ``[batch_size,]``
        """
        feats = feats.transpose(0, 1)
        tags = tags.transpose(0, 1).long()
        mask = mask.transpose(0, 1).float()
        all_path_score = self._normalizer_likelihood(feats, mask)
        gold_path_score = self._gold_score(feats, tags, mask)

        return all_path_score - gold_path_score

    def viterbi_decode(self, logits: "torch.FloatTensor", mask: "torch.ByteTensor", unpad=False):
        r"""
        Given a **feature matrix** and the **transition score matrix**, compute the best path
        and its score.

        :param logits: feature matrix of shape ``[batch_size, max_len, num_tags]``
        :param mask: mask of shape ``[batch_size, max_len]``; positions that are **0** are treated
            as padding. If ``None``, no padding is assumed.
        :param unpad: whether to strip padding from the result:

            - if ``False``, return a ``[batch_size, max_len]`` tensor
            - if ``True``, return a :class:`List` [:class:`List` [:class:`int`]] where each inner
              :class:`List` [:class:`int`] holds the labels of one sequence with the pad part removed,
              i.e. the length of each inner :class:`List` [:class:`int`] is the valid length of that sample.

        :return: (paths, scores).

            - ``paths`` -- the decoded paths; their form depends on the ``unpad`` parameter.
            - ``scores`` -- :class:`torch.FloatTensor` of shape ``[batch_size,]``, the score of each best path.
        """
        batch_size, max_len, n_tags = logits.size()
        seq_len = mask.long().sum(1)
        logits = logits.transpose(0, 1).data  # L, B, H
        mask = mask.transpose(0, 1).data.eq(True)  # L, B
        flip_mask = mask.eq(False)

        # dp
        vpath = logits.new_zeros((max_len, batch_size, n_tags), dtype=torch.long)
        vscore = logits[0]  # bsz x n_tags
        transitions = self._constrain.data.clone()
        transitions[:n_tags, :n_tags] += self.trans_m.data
        if self.include_start_end_trans:
            transitions[n_tags, :n_tags] += self.start_scores.data
            transitions[:n_tags, n_tags + 1] += self.end_scores.data

        vscore += transitions[n_tags, :n_tags]

        trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
        end_trans_score = transitions[:n_tags, n_tags + 1].view(1, 1, n_tags).repeat(batch_size, 1, 1)  # bsz, 1, n_tags

        # handle sequences of length 1
        vscore += transitions[:n_tags, n_tags + 1].view(1, n_tags).repeat(batch_size, 1) \
            .masked_fill(seq_len.ne(1).view(-1, 1), 0)
        for i in range(1, max_len):
            prev_score = vscore.view(batch_size, n_tags, 1)
            cur_score = logits[i].view(batch_size, 1, n_tags) + trans_score
            score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0)  # bsz x n_tag x n_tag
            # the current position may be the last one of its sequence
            score += end_trans_score.masked_fill(seq_len.ne(i + 1).view(-1, 1, 1), 0)

            best_score, best_dst = score.max(1)
            vpath[i] = best_dst
            # since backtracking starts from last_tags, vscore must be kept at every position
            vscore = best_score.masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
                vscore.masked_fill(mask[i].view(batch_size, 1), 0)

        # backtrace
        batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
        seq_idx = torch.arange(max_len, dtype=torch.long, device=logits.device)
        lens = (seq_len - 1)
        # idxes [L, B], batched idx from seq_len-1 to 0
        idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % max_len

        ans = logits.new_empty((max_len, batch_size), dtype=torch.long)
        ans_score, last_tags = vscore.max(1)
        ans[idxes[0], batch_idx] = last_tags
        for i in range(max_len - 1):
            last_tags = vpath[idxes[i], batch_idx, last_tags]
            ans[idxes[i + 1], batch_idx] = last_tags
        ans = ans.transpose(0, 1)
        if unpad:
            paths = []
            for idx, seq_len_i in enumerate(lens):
                paths.append(ans[idx, :seq_len_i + 1].tolist())
        else:
            paths = ans
        return paths, ans_score
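As a quick sanity check, the class can be exercised on random inputs. The module file name ner_crf.py is taken from the import in the training script below; the tensors themselves are made up for this example:

import torch
from ner_crf import ConditionalRandomField

crf = ConditionalRandomField(num_tags=7)
feats = torch.randn(2, 5, 7)                     # [batch_size, max_len, num_tags]
tags = torch.randint(0, 7, (2, 5))               # [batch_size, max_len]
mask = torch.tensor([[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]])           # second sample has 2 padding positions
loss = crf(feats, tags, mask).mean()             # [batch_size,] -> scalar for backward()
paths, scores = crf.viterbi_decode(feats, mask, unpad=True)
print(loss.item(), paths)                        # paths: one list of label ids per sample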
The training code is as follows:
import shutil
from functools import partial

import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_from_disk
from ignite.engine import Engine, Events
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertTokenizer, BertForTokenClassification

from ner_crf import ConditionalRandomField

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train_step(engine, batch_data):
    data_inputs, data_labels = batch_data
    loss = engine.estimator(**data_inputs, labels=data_labels)
    engine.optimizer.zero_grad()
    loss.backward()
    engine.optimizer.step()
    # scale the mean loss by the batch size (not by len() of the inputs dict)
    return {'sample_loss': loss.item() * data_labels.size(0)}


def collate_function(tokenizer, batch_data):
    data_inputs, data_labels = [], []
    for data in batch_data:
        data_inputs.append(data['sentence'])
        data_labels.append(torch.tensor(data['label'], device=device))

    data_inputs = tokenizer(data_inputs,
                            padding='longest',
                            add_special_tokens=False,
                            return_token_type_ids=False,
                            return_tensors='pt')
    data_inputs = {key: value.to(device) for key, value in data_inputs.items()}
    data_labels = pad_sequence(data_labels, batch_first=True, padding_value=-100)

    return data_inputs, data_labels


def on_epoch_started(engine):
    engine.estimator.train()
    engine.total_sample_losses = 0
    engine.progress = tqdm(range(engine.iter_nums), ncols=110)


def on_epoch_completed(engine):
    engine.progress.close()
    do_evaluation(engine)


def on_iteration_completed(engine):
    max_epoch = engine.state.max_epochs
    cur_epoch = engine.state.epoch
    engine.total_sample_losses += engine.state.output['sample_loss']
    desc = f'training epoch {cur_epoch:2d}/{max_epoch:2d} loss {engine.total_sample_losses:12.4f}'
    engine.progress.set_description(desc)
    engine.progress.update()


@torch.no_grad()
def do_evaluation(engine):
    engine.estimator.eval()
    progress = tqdm(range(len(engine.testloader)), ncols=110)
    max_epoch = engine.state.max_epochs
    cur_epoch = engine.state.epoch

    total_sample_losses = 0
    for data_inputs, data_labels in engine.testloader:
        loss = engine.estimator(**data_inputs, labels=data_labels)
        total_sample_losses += loss.item() * data_labels.size(0)
        desc = f'evaluate epoch {cur_epoch:2d}/{max_epoch:2d} loss {total_sample_losses:12.4f}'
        progress.set_description(desc)
        progress.update()
    progress.close()

    loss = round(total_sample_losses, 4)
    checkpoint = f'finish/bert-crf/epoch-{cur_epoch}-train-{engine.total_sample_losses:.4f}-test-{total_sample_losses:.4f}'

    def save_checkpoint():
        # BertCRF is a plain nn.Module with no save_pretrained of its own, so save
        # the inner BERT with save_pretrained and the CRF parameters separately
        engine.estimator.model.save_pretrained(checkpoint)
        engine.tokenizer.save_pretrained(checkpoint)
        torch.save(engine.estimator.crf.state_dict(), f'{checkpoint}/crf.bin')

    # keep the 3 models with the smallest losses
    if len(engine.checkpoints) < 3:
        engine.checkpoints.append({'checkpoint': checkpoint, 'loss': loss})
        save_checkpoint()
        return

    # if the current checkpoint's loss exceeds the current worst, do not store it
    engine.checkpoints = sorted(engine.checkpoints, key=lambda x: x['loss'], reverse=True)
    if loss > engine.checkpoints[0]['loss']:
        return

    # delete the model with the largest loss
    shutil.rmtree(engine.checkpoints[0]['checkpoint'])
    engine.checkpoints.pop(0)

    engine.checkpoints.append({'checkpoint': checkpoint, 'loss': loss})
    save_checkpoint()


class BertCRF(nn.Module):

    def __init__(self):
        super(BertCRF, self).__init__()
        self.model = BertForTokenClassification.from_pretrained('pretrained/bert-base-chinese', num_labels=7)
        self.crf = ConditionalRandomField(num_tags=7)

    def forward(self, *args, **kwargs):
        labels = kwargs.pop('labels')
        outputs = self.model(*args, **kwargs)
        # the CRF cannot index the -100 padding tag, so replace pads with tag 0;
        # those positions are masked out of the loss anyway
        mask = (labels != -100)
        labels[~mask] = 0
        loss = self.crf(feats=outputs.logits, tags=labels, mask=mask)
        return torch.mean(loss)


def do_train():
    checkpoint = 'pretrained/bert-base-chinese'
    estimator = BertCRF().to(device)
    tokenizer = BertTokenizer.from_pretrained(checkpoint)
    optimizer = optim.Adam(estimator.parameters(), lr=1e-5)

    traindata = load_from_disk('data/train_valid.data')
    testdata = load_from_disk('data/test.data')
    trainloader = DataLoader(traindata, batch_size=4, collate_fn=partial(collate_function, tokenizer))
    testloader = DataLoader(testdata, batch_size=16, collate_fn=partial(collate_function, tokenizer))

    trainer = Engine(train_step)
    trainer.estimator = estimator
    trainer.optimizer = optimizer
    trainer.tokenizer = tokenizer
    trainer.iter_nums = len(trainloader)
    trainer.testloader = testloader
    trainer.checkpoints = []
    trainer.add_event_handler(Events.EPOCH_STARTED, on_epoch_started)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, on_epoch_completed)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)
    trainer.run(trainloader, max_epochs=20)


if __name__ == '__main__':
    do_train()
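The training script above never decodes. For completeness, here is a minimal inference sketch; do_predict is a hypothetical helper that is not part of the original script, and it assumes the estimator and tokenizer have already been built as in do_train:

@torch.no_grad()
def do_predict(estimator, tokenizer, sentence):
    # hypothetical helper, not part of the original script
    estimator.eval()
    inputs = tokenizer([sentence],
                       padding='longest',
                       add_special_tokens=False,
                       return_token_type_ids=False,
                       return_tensors='pt')
    inputs = {key: value.to(device) for key, value in inputs.items()}
    # bypass BertCRF.forward (which requires labels) and take the raw logits
    logits = estimator.model(**inputs).logits       # [1, len, num_tags]
    paths, scores = estimator.crf.viterbi_decode(logits, inputs['attention_mask'], unpad=True)
    return paths[0]                                 # one label id per token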