Implementing NER with BERT + CRF

BERT + CRF is currently one of the most common architectures for solving NER problems. BERT's multi-head self-attention builds relationships between tokens and produces a first-stage prediction of BIO labels. At this stage, however, the model is not asked to learn the constraints between adjacent labels in a sequence, for example: I_PERSON cannot be the first label of a sequence, and I_PERSON cannot be followed by I_LOC. Adding this kind of constraint learning makes the NER results more robust. CRF is the technique used to learn these label-transition constraints from the training samples.
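To make the constraint idea concrete, here is a minimal sketch (my own illustration, not part of the fastNLP code below; the tag names are only examples) of checking BIO transitions by hand:

def is_allowed(from_tag: str, to_tag: str) -> bool:
    """Return True if the BIO transition ``from_tag -> to_tag`` is legal."""
    # an I_X tag may only follow B_X or I_X of the same entity type
    if to_tag.startswith('I_'):
        entity = to_tag[2:]
        return from_tag in (f'B_{entity}', f'I_{entity}')
    # O and B_X tags may follow any tag
    return True

# 'START' marks the beginning of a sequence: an I_* tag can never come first
print(is_allowed('START', 'I_PERSON'))     # False: I_PERSON cannot open a sequence
print(is_allowed('I_PERSON', 'I_LOC'))     # False: I_PERSON cannot be followed by I_LOC
print(is_allowed('B_PERSON', 'I_PERSON'))  # True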

First, here is the CRF implementation. I walked through the underlying idea in a previous article; the implementation shown here is taken from the fastNLP library (thanks to its authors). The code is as follows:

import torch.nn as nn
import torch
from typing import Union, List, Tuple


class ConditionalRandomField(nn.Module):
    r"""
    条件随机场。提供 :meth:`forward` 以及 :meth:`viterbi_decode` 两个方法,分别用于 **训练** 与 **inference** 。

    :param num_tags: 标签的数量
    :param include_start_end_trans: 是否考虑各个 tag 作为开始以及结尾的分数。
    :param allowed_transitions: 内部的 ``Tuple[from_tag_id(int), to_tag_id(int)]`` 视为允许发生的跃迁,其他没
        有包含的跃迁认为是禁止跃迁,可以通过 :func:`allowed_transitions` 函数得到;如果为 ``None`` ,则所有跃迁均为合法。
    """

    def __init__(self, num_tags:int, include_start_end_trans:bool=False, allowed_transitions:List=None):
        super(ConditionalRandomField, self).__init__()

        self.include_start_end_trans = include_start_end_trans
        self.num_tags = num_tags

        # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score
        self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags))
        if self.include_start_end_trans:
            self.start_scores = nn.Parameter(torch.randn(num_tags))
            self.end_scores = nn.Parameter(torch.randn(num_tags))

        if allowed_transitions is None:
            constrain = torch.zeros(num_tags + 2, num_tags + 2)
        else:
            constrain = torch.full((num_tags + 2, num_tags + 2), fill_value=-10000.0, dtype=torch.float)
            has_start = False
            has_end = False
            for from_tag_id, to_tag_id in allowed_transitions:
                constrain[from_tag_id, to_tag_id] = 0
                if from_tag_id==num_tags:
                    has_start = True
                if to_tag_id==num_tags+1:
                    has_end = True
            if not has_start:
                constrain[num_tags, :].fill_(0)
            if not has_end:
                constrain[:, num_tags+1].fill_(0)
        self._constrain = nn.Parameter(constrain, requires_grad=False)

    def _normalizer_likelihood(self, logits, mask):
        r"""Computes the (batch_size,) denominator term for the log-likelihood, which is the
        sum of the likelihoods across all possible state sequences.

        :param logits: FloatTensor, ``[max_len, batch_size, num_tags]``
        :param mask: ByteTensor, ``[max_len, batch_size]``
        :return: FloatTensor, ``[batch_size,]``
        """
        seq_len, batch_size, n_tags = logits.size()
        alpha = logits[0]
        if self.include_start_end_trans:
            alpha = alpha + self.start_scores.view(1, -1)

        flip_mask = mask.eq(False)

        for i in range(1, seq_len):
            emit_score = logits[i].view(batch_size, 1, n_tags)
            trans_score = self.trans_m.view(1, n_tags, n_tags)
            tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
            alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
                    alpha.masked_fill(mask[i].eq(True).view(batch_size, 1), 0)

        if self.include_start_end_trans:
            alpha = alpha + self.end_scores.view(1, -1)

        return torch.logsumexp(alpha, 1)

    def _gold_score(self, logits, tags, mask):
        r"""
        Compute the score for the gold path.
        :param logits: FloatTensor, ``[max_len, batch_size, num_tags]``
        :param tags: LongTensor, ``[max_len, batch_size]``
        :param mask: ByteTensor, ``[max_len, batch_size]``
        :return: FloatTensor, ``[batch_size,]``
        """
        seq_len, batch_size, _ = logits.size()
        batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
        seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

        # trans_score [L-1, B]
        mask = mask.eq(True)
        flip_mask = mask.eq(False)
        trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
        # emit_score [L, B]
        emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0)
        # score [L-1, B]
        score = trans_score + emit_score[:seq_len - 1, :]
        score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0)
        if self.include_start_end_trans:
            st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
            last_idx = mask.long().sum(0) - 1
            ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
            score = score + st_scores + ed_scores
        # return [B,]
        return score

    def forward(self, feats: "torch.FloatTensor", tags: "torch.LongTensor", mask: "torch.ByteTensor") -> "torch.FloatTensor":
        r"""
        用于计算 ``CRF`` 的前向 loss,返回值为一个形状为 ``[batch_size,]`` 的 :class:`torch.FloatTensor` ,可能需要 :func:`mean` 求得 loss 。

        :param feats: 特征矩阵,形状为 ``[batch_size, max_len, num_tags]``
        :param tags: 标签矩阵,形状为 ``[batch_size, max_len]``
        :param mask: 形状为 ``[batch_size, max_len]`` ,为 **0** 的位置认为是 padding。
        :return: ``[batch_size,]``
        """
        feats = feats.transpose(0, 1)
        tags = tags.transpose(0, 1).long()
        mask = mask.transpose(0, 1).float()
        all_path_score = self._normalizer_likelihood(feats, mask)
        gold_path_score = self._gold_score(feats, tags, mask)

        return all_path_score - gold_path_score

    def viterbi_decode(self, logits: "torch.FloatTensor", mask: "torch.ByteTensor", unpad=False):
        r"""给定一个 **特征矩阵** 以及 **转移分数矩阵** ,计算出最佳的路径以及对应的分数

        :param logits: 特征矩阵,形状为 ``[batch_size, max_len, num_tags]``
        :param mask: 标签矩阵,形状为 ``[batch_size, max_len]`` ,为 **0** 的位置认为是 padding。如果为 ``None`` ,则认为没有 padding。
        :param unpad: 是否将结果删去 padding:

                - 为 ``False`` 时,返回的是 ``[batch_size, max_len]`` 的张量
                - 为 ``True`` 时,返回的是 :class:`List` [:class:`List` [ :class:`int` ]], 内部的 :class:`List` [:class:`int` ] 为每个
                  sequence 的 label ,已经除去 pad 部分,即每个 :class:`List` [ :class:`int` ] 的长度是这个 sample 的有效长度。
        :return:  (paths, scores)。

                - ``paths`` -- 解码后的路径, 其值参照 ``unpad`` 参数.
                - ``scores`` -- :class:`torch.FloatTensor` ,形状为 ``[batch_size,]`` ,对应每个最优路径的分数。

        """
        batch_size, max_len, n_tags = logits.size()
        seq_len = mask.long().sum(1)
        logits = logits.transpose(0, 1).data  # L, B, H
        mask = mask.transpose(0, 1).data.eq(True)  # L, B
        flip_mask = mask.eq(False)

        # dp
        vpath = logits.new_zeros((max_len, batch_size, n_tags), dtype=torch.long)
        vscore = logits[0]  # bsz x n_tags
        transitions = self._constrain.data.clone()
        transitions[:n_tags, :n_tags] += self.trans_m.data
        if self.include_start_end_trans:
            transitions[n_tags, :n_tags] += self.start_scores.data
            transitions[:n_tags, n_tags + 1] += self.end_scores.data

        vscore += transitions[n_tags, :n_tags]

        trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
        end_trans_score = transitions[:n_tags, n_tags+1].view(1, 1, n_tags).repeat(batch_size, 1, 1) # bsz, 1, n_tags

        # handle sequences of length 1
        vscore += transitions[:n_tags, n_tags+1].view(1, n_tags).repeat(batch_size, 1) \
            .masked_fill(seq_len.ne(1).view(-1, 1), 0)
        for i in range(1, max_len):
            prev_score = vscore.view(batch_size, n_tags, 1)
            cur_score = logits[i].view(batch_size, 1, n_tags) + trans_score
            score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0)  # bsz x n_tag x n_tag
            # account for the case where the current position is the last token of its sequence
            score += end_trans_score.masked_fill(seq_len.ne(i+1).view(-1, 1, 1), 0)

            best_score, best_dst = score.max(1)
            vpath[i] = best_dst
            # backtracking is done via last_tags, so vscore at padded positions must be carried over unchanged
            vscore = best_score.masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
                     vscore.masked_fill(mask[i].view(batch_size, 1), 0)

        # backtrace
        batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
        seq_idx = torch.arange(max_len, dtype=torch.long, device=logits.device)
        lens = (seq_len - 1)
        # idxes [L, B], batched idx from seq_len-1 to 0
        idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % max_len

        ans = logits.new_empty((max_len, batch_size), dtype=torch.long)
        ans_score, last_tags = vscore.max(1)
        ans[idxes[0], batch_idx] = last_tags
        for i in range(max_len - 1):
            last_tags = vpath[idxes[i], batch_idx, last_tags]
            ans[idxes[i + 1], batch_idx] = last_tags
        ans = ans.transpose(0, 1)
        if unpad:
            paths = []
            for idx, max_len in enumerate(lens):
                paths.append(ans[idx, :max_len + 1].tolist())
        else:
            paths = ans
        return paths, ans_score
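
Before moving on to training, here is a small usage sketch of the class above with random tensors. This is only a sanity check I added, not part of the fastNLP code; the shapes follow the docstrings, and it assumes the class is saved as ner_crf.py:

import torch
from ner_crf import ConditionalRandomField  # assuming the class above lives in ner_crf.py

batch_size, max_len, num_tags = 2, 5, 7
crf = ConditionalRandomField(num_tags=num_tags)

feats = torch.randn(batch_size, max_len, num_tags)        # emission scores, e.g. BERT logits
tags = torch.randint(0, num_tags, (batch_size, max_len))  # gold BIO label ids
mask = torch.ones(batch_size, max_len, dtype=torch.bool)  # 1 = real token, 0 = padding

loss = crf(feats=feats, tags=tags, mask=mask).mean()      # negative log-likelihood, averaged over the batch
paths, scores = crf.viterbi_decode(feats, mask)           # best tag sequence and score per sample
print(loss.item(), paths.shape, scores.shape)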

The training code is as follows:

from transformers import BertTokenizer
from transformers import BertForTokenClassification
from datasets import load_from_disk
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from ignite.engine import Engine
from functools import partial
from torch.nn.utils.rnn import pad_sequence
from ignite.engine import Events
from tqdm import tqdm
import shutil

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_step(engine, batch_data):
    data_inputs, data_labels = batch_data
    loss = engine.estimator(**data_inputs, labels=data_labels)
    engine.optimizer.zero_grad()
    loss.backward()
    engine.optimizer.step()

    return {'sample_loss': loss.item() * len(data_labels)}


def collate_function(tokenizer, batch_data):
    data_inputs, data_labels = [], []
    for data in batch_data:
        data_inputs.append(data['sentence'])
        data_labels.append(torch.tensor(data['label'], device=device))
    data_inputs = tokenizer(data_inputs, padding='longest', add_special_tokens=False, return_token_type_ids=False, return_tensors='pt')
    data_inputs = {key: value.to(device) for key, value in data_inputs.items()}
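    # pad label sequences with -100 so that BertCRF.forward can distinguish real tokens from padding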
    data_labels = pad_sequence(data_labels, batch_first=True, padding_value=-100)
    return data_inputs, data_labels

def on_epoch_started(engine):
    engine.estimator.train()
    engine.total_sample_losses = 0
    engine.progress = tqdm(range(engine.iter_nums), ncols=110)

def on_epoch_completed(engine):
    engine.progress.close()
    do_evaluation(engine)

def on_iteration_completed(engine):
    max_epoch = engine.state.max_epochs
    cur_epoch = engine.state.epoch
    engine.total_sample_losses += engine.state.output['sample_loss']

    desc = f'training epoch {cur_epoch:2d}/{max_epoch:2d} loss {engine.total_sample_losses:12.4f}'
    engine.progress.set_description(desc)
    engine.progress.update()

@torch.no_grad()
def do_evaluation(engine):
    engine.estimator.eval()
    progress = tqdm(range(len(engine.testloader)), ncols=110)
    max_epoch = engine.state.max_epochs
    cur_epoch = engine.state.epoch
    total_sample_losses = 0
    for data_input, data_labels in engine.testloader:
        loss = engine.estimator(**data_input, labels=data_labels)
        total_sample_losses += loss.item() * len(data_labels)
        desc = f'evaluate epoch {cur_epoch:2d}/{max_epoch:2d} loss {total_sample_losses:12.4f}'
        progress.set_description(desc)
        progress.update()
    progress.close()

    loss = round(total_sample_losses, 4)
    checkpoint = f'finish/bert-crf/epoch-{cur_epoch}-train-{engine.total_sample_losses:.4f}-test-{total_sample_losses:.4f}'
    # keep the 3 checkpoints with the lowest loss
    if len(engine.checkpoints) < 3:
        engine.checkpoints.append({'checkpoint': checkpoint, 'loss': loss})
        engine.estimator.save_pretrained(checkpoint)
        engine.tokenizer.save_pretrained(checkpoint)
        return

    # if the current loss is larger than the worst stored checkpoint, do not save it
    engine.checkpoints = sorted(engine.checkpoints, key=lambda x: x['loss'], reverse=True)
    if loss > engine.checkpoints[0]['loss']:
        return

    # delete the checkpoint with the largest loss
    shutil.rmtree(engine.checkpoints[0]['checkpoint'])
    engine.checkpoints.pop(0)
    engine.checkpoints.append({'checkpoint': checkpoint, 'loss': loss})
    engine.estimator.save_pretrained(checkpoint)
    engine.tokenizer.save_pretrained(checkpoint)


from ner_crf import ConditionalRandomField
class BertCRF(nn.Module):

    def __init__(self):
        super(BertCRF, self).__init__()
        self.model = BertForTokenClassification.from_pretrained('pretrained/bert-base-chinese', num_labels=7)
        self.crf = ConditionalRandomField(num_tags=7)

    def forward(self, *args, **kwargs):
        labels = kwargs.pop('labels')
        outputs = self.model(*args, **kwargs)
        mask = (labels != -100)
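        # -100 marks padded positions and is not a valid tag index, so reset those labels to 0;
        # the mask already keeps padded positions from contributing to the CRF score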
        labels[~mask] = 0
        loss = self.crf(feats=outputs.logits, tags=labels, mask=mask)
        return torch.mean(loss)


def do_train():

    checkpoint = 'pretrained/bert-base-chinese'
    estimator = BertCRF().to(device)
    tokenizer = BertTokenizer.from_pretrained(checkpoint)
    optimizer = optim.Adam(estimator.parameters(), lr=1e-5)
    traindata = load_from_disk('data/train_valid.data')
    testdata = load_from_disk('data/test.data')
    trainloader = DataLoader(traindata, batch_size=4, collate_fn=partial(collate_function, tokenizer))
    testloader = DataLoader(testdata, batch_size=16, collate_fn=partial(collate_function, tokenizer))

    trainer = Engine(train_step)
    trainer.estimator = estimator
    trainer.optimizer = optimizer
    trainer.tokenizer = tokenizer
    trainer.iter_nums = len(trainloader)
    trainer.testloader = testloader
    trainer.checkpoints = []

    trainer.add_event_handler(Events.EPOCH_STARTED, on_epoch_started)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, on_epoch_completed)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)

    trainer.run(trainloader, max_epochs=20)


if __name__ == '__main__':
    do_train()
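
The script above only trains the model and tracks the loss. As a rough sketch of how inference could look (this helper is my own assumption, not part of the original script; it reuses the same tokenization settings as collate_function and runs viterbi_decode on the BERT logits):

@torch.no_grad()
def predict(estimator: BertCRF, tokenizer, sentence: str):
    """Decode BIO tag ids for a single sentence with a trained BertCRF model."""
    estimator.eval()
    inputs = tokenizer([sentence], padding='longest', add_special_tokens=False,
                       return_token_type_ids=False, return_tensors='pt')
    inputs = {key: value.to(device) for key, value in inputs.items()}
    logits = estimator.model(**inputs).logits   # [1, seq_len, num_tags]
    mask = inputs['attention_mask'].bool()      # [1, seq_len]
    paths, scores = estimator.crf.viterbi_decode(logits, mask, unpad=True)
    return paths[0], scores[0].item()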