Chinese Text Summarization with PEGASUS

PEGASUS is an encoder-decoder model. In this post we fine-tune our own abstractive text summarization model on top of an open-source PEGASUS pre-trained checkpoint. Thanks to https://huggingface.co/IDEA-CCNL for open-sourcing the pre-trained model; this example continues training it on part of the LCSTS dataset. For background on how the PEGASUS model itself works, see the separate article I wrote about it.

import jieba
jieba.setLogLevel(0)  # silence jieba's logging output

import torch
from transformers import PegasusForConditionalGeneration
# PegasusTokenizer comes from tokenizers_pegasus.py, copied from the model repository
from tokenizers_pegasus import PegasusTokenizer
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import copy
from torch.nn.utils.rnn import pad_sequence
import glob
from rouge import Rouge
import numpy as np
from torchtext.data.metrics import bleu_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Next, we walk through the details and code for constructing the model inputs, training, evaluation, and inference.

1. Constructing the Input

The dataset we use is LCSTS, whose files are:

LCSTS
|-- test.src.txt
|-- test.tgt.txt
|-- train.src.txt
|-- train.tgt.txt
|-- valid.src.txt
|-- valid.tgt.txt

As the file names suggest, the dataset contains training, test, and validation splits. Here is what the training files look like:

train.src.txt

新华社受权于18日全文播发修改后的《中华人民共和国立法法》,修改后的立法法分为“总则”“法律”“行政法规”“地方性法规、自治条例和单行条例、规章”“适用与备案审查”“附则”等6章,共计105条。
一辆小轿车,一名女司机,竟造成9死24伤。日前,深圳市交警局对事故进行通报:从目前证据看,事故系司机超速行驶且操作不当导致。目前24名伤员已有6名治愈出院,其余正接受治疗,预计事故赔偿费或超一千万元。
...

train.tgt.txt

修改后的立法法全文公布
深圳机场9死24伤续:司机全责赔偿或超千万
...

Each line of source text is paired with one line of summary. Since PEGASUS is an encoder-decoder model, training requires both encoder inputs and decoder inputs. How do we construct them?

For the transformers implementation of PEGASUS, we only need to supply the encoder inputs and the labels; internally, the model derives the decoder inputs from the labels:

if labels is not None:
    if use_cache:
        logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
    use_cache = False
    if decoder_input_ids is None:
        decoder_input_ids = shift_tokens_right(
            labels, self.config.pad_token_id, self.config.decoder_start_token_id
        )


# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
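
To make the shift concrete, here is a tiny example (the token ids, pad_token_id=0, and decoder_start_token_id=0 are made up purely for illustration): each row is shifted right by one position, the decoder start token is prepended, and any -100 label padding that lands in the decoder input is replaced by the pad id.

labels = torch.tensor([[5, 6, 7, -100],
                       [8, 9, -100, -100]])
decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=0)
# decoder_input_ids:
# tensor([[0, 5, 6, 7],
#         [0, 8, 9, 0]])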

In this part we only need to read the data; converting it into model inputs is deferred to training time:

def data_read(type='train'):
    """Read the parallel source / summary files and return two aligned lists."""
    all_inputs, all_labels = [], []
    for inputs, labels in zip(open('LCSTS/' + type + '.src.txt'), open('LCSTS/' + type + '.tgt.txt')):
        all_inputs.append(inputs.strip())
        all_labels.append(labels.strip())

    return all_inputs, all_labels


class SummaryDataset:
    """Map-style dataset yielding (source text, reference summary) pairs."""

    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels
        self.counts = len(self.labels)

    def __len__(self):
        return self.counts

    def __getitem__(self, index):
        return self.inputs[index], self.labels[index]
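
A quick sanity check of the data pipeline (assuming the LCSTS files are laid out as above):

train_inputs, train_labels = data_read(type='train')
train_data = SummaryDataset(train_inputs, train_labels)
print(len(train_data))   # number of (text, summary) pairs
print(train_data[0])     # ('新华社受权于18日全文播发修改后的...', '修改后的立法法全文公布')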

2. Training Function

Note that the checkpoint we download from Hugging Face is the authors' already fine-tuned summarization model, so what we do here is effectively continued training.

Also, copy data_utils.py and tokenizers_pegasus.py from the downloaded model directory into the current directory. This is necessary because the authors integrated jieba word segmentation into the vocabulary: it contains many word-level tokens, so a meaningful word can be encoded as a single id, which the default PEGASUS tokenizer cannot do.

One caveat: after copying the files, you may hit errors at runtime; just follow the error messages and comment out the jieba-related parts of data_utils.py that raise them.
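
To see what the word-level vocabulary buys us, you can inspect the custom tokenizer's output on a short sentence (the exact split depends on jieba and the vocabulary, so the tokens below are only indicative):

tokenizer = PegasusTokenizer.from_pretrained('Randeng-Pegasus-238M-Summary-Chinese')
print(tokenizer.tokenize('智能手机已经席卷全球'))
# e.g. ['智能手机', '已经', '席卷', '全球'] -- word-level pieces rather than single characters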

The original LCSTS training set has over 2.4 million examples, and the model is an encoder-decoder, so training is slow and memory hungry. To speed things up I use gradient accumulation and only 80,000 of the training examples (batch size 4 with 32 accumulation steps gives an effective batch size of 128). With that, one epoch takes about 33 minutes on my machine (CentOS 7 + RTX 2060 + 6 GB VRAM).

The complete training code:

def train():

    estimator = PegasusForConditionalGeneration.from_pretrained('Randeng-Pegasus-238M-Summary-Chinese').to(device)
    tokenizer = PegasusTokenizer.from_pretrained('Randeng-Pegasus-238M-Summary-Chinese')
    optimizer = optim.Adam(estimator.parameters(), lr=3e-5)
    criterion = nn.CrossEntropyLoss()

    train_inputs, train_labels = data_read(type='train')
    train_number = 80000
    train_inputs, train_labels = train_inputs[:train_number], train_labels[:train_number]
    train_data = SummaryDataset(train_inputs, train_labels)

    def collate_function(batch_data):

        encoder_inputs, decoder_inputs = [], []
        for einputs, dinputs in batch_data:
            encoder_inputs.append(einputs)
            decoder_inputs.append(dinputs)

        # Encode the summaries as label token ids
        summary_labels = tokenizer.batch_encode_plus(decoder_inputs, return_attention_mask=False)
        summary_labels = [torch.tensor(input_ids, device=device) for input_ids in summary_labels['input_ids']]
        summary_labels = pad_sequence(summary_labels, padding_value=criterion.ignore_index, batch_first=True)

        # Encode the source texts into token ids
        encoder_inputs = tokenizer(encoder_inputs, max_length=1024, truncation=True, padding='longest', return_tensors="pt")
        encoder_inputs = { key: value.to(device) for key, value in encoder_inputs.items()}
        encoder_inputs['labels'] = summary_labels

        return encoder_inputs

    dataloader = DataLoader(train_data, batch_size=4, shuffle=True, collate_fn=collate_function)

    for epoch in range(10):

        estimator.train()
        progress = tqdm(range(len(dataloader)))
        total_loss = 0.0
        accumulate_step = 0

        optimizer.zero_grad()
        for encoder_inputs in dataloader:
            outputs = estimator(**encoder_inputs)

            # accumulate gradients over 32 batches; scale the loss so the
            # combined update behaves like one batch of size 4 * 32 = 128
            (outputs.loss / 32).backward()

            accumulate_step += 1
            if accumulate_step % 32 == 0 or accumulate_step == len(dataloader):
                optimizer.step()
                optimizer.zero_grad()

            total_loss += outputs.loss.item()
            progress.set_description('epoch %d loss %8.2f' % (epoch + 1, total_loss))
            progress.update()
        progress.close()

        # save a checkpoint after every epoch
        checkpoint = 'model/%d-pegasus-summary-%.2f' % (epoch + 1, total_loss)
        estimator.save_pretrained(checkpoint)
        tokenizer.save_pretrained(checkpoint)
        # evaluate on the held-out split
        train_evaluate(estimator, tokenizer)
        print('-' * 100)
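
Training is then kicked off by simply calling the function (a minimal entry point, assuming everything above lives in one script):

if __name__ == '__main__':
    train()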

Training output:

epoch 1 loss 59105.36: 100%|██████████████| 20000/20000 [31:08<00:00, 10.70it/s]
100%|█████████████████████████████████████████| 553/553 [01:40<00:00,  5.48it/s]
Rouge-1: 29.16
Rouge-2: 14.5
Rouge-L: 27.19
BLEU-4: 0.1161
----------------------------------------------------------------------------------------------------
epoch 2 loss 53440.37: 100%|██████████████| 20000/20000 [31:07<00:00, 10.71it/s]
100%|█████████████████████████████████████████| 553/553 [01:40<00:00,  5.50it/s]
Rouge-1: 29.4
Rouge-2: 14.85
Rouge-L: 27.54
BLEU-4: 0.1185
----------------------------------------------------------------------------------------------------
epoch 3 loss 49058.87: 100%|██████████████| 20000/20000 [31:08<00:00, 10.70it/s]
100%|█████████████████████████████████████████| 553/553 [01:41<00:00,  5.42it/s]
Rouge-1: 29.93
Rouge-2: 14.97
Rouge-L: 27.9
BLEU-4: 0.1191
----------------------------------------------------------------------------------------------------
epoch 4 loss 45132.87: 100%|██████████████| 20000/20000 [31:08<00:00, 10.71it/s]
100%|█████████████████████████████████████████| 553/553 [01:39<00:00,  5.58it/s]
Rouge-1: 30.22
Rouge-2: 15.0
Rouge-L: 28.22
BLEU-4: 0.1192
----------------------------------------------------------------------------------------------------
epoch 5 loss 41556.33: 100%|██████████████| 20000/20000 [31:09<00:00, 10.70it/s]
100%|█████████████████████████████████████████| 553/553 [01:38<00:00,  5.60it/s]
Rouge-1: 30.12
Rouge-2: 15.12
Rouge-L: 28.06
BLEU-4: 0.1192
----------------------------------------------------------------------------------------------------
epoch 6 loss 38239.26: 100%|██████████████| 20000/20000 [31:08<00:00, 10.70it/s]
100%|█████████████████████████████████████████| 553/553 [01:37<00:00,  5.68it/s]
Rouge-1: 30.01
Rouge-2: 14.86
Rouge-L: 28.14
BLEU-4: 0.1159
----------------------------------------------------------------------------------------------------
epoch 7 loss 35088.80: 100%|██████████████| 20000/20000 [31:09<00:00, 10.70it/s]
100%|█████████████████████████████████████████| 553/553 [01:37<00:00,  5.67it/s]
Rouge-1: 30.57
Rouge-2: 15.54
Rouge-L: 28.66
BLEU-4: 0.1232
----------------------------------------------------------------------------------------------------
epoch 8 loss 32208.71: 100%|██████████████| 20000/20000 [31:08<00:00, 10.70it/s]
100%|█████████████████████████████████████████| 553/553 [01:36<00:00,  5.71it/s]
Rouge-1: 30.08
Rouge-2: 15.37
Rouge-L: 28.23
BLEU-4: 0.1232
----------------------------------------------------------------------------------------------------
epoch 9 loss 29451.46: 100%|██████████████| 20000/20000 [31:02<00:00, 10.74it/s]
100%|█████████████████████████████████████████| 553/553 [01:34<00:00,  5.87it/s]
Rouge-1: 30.22
Rouge-2: 15.18
Rouge-L: 28.34
BLEU-4: 0.1183
----------------------------------------------------------------------------------------------------
epoch 10 loss 26907.70: 100%|█████████████| 20000/20000 [31:02<00:00, 10.74it/s]
100%|█████████████████████████████████████████| 553/553 [01:34<00:00,  5.83it/s]
Rouge-1: 30.58
Rouge-2: 15.64
Rouge-L: 28.6
BLEU-4: 0.1229
----------------------------------------------------------------------------------------------------

3. Evaluation Function

Evaluation scores the generated summaries with ROUGE-1, ROUGE-2, ROUGE-L, and BLEU-4. The code:

@torch.no_grad()
def train_evaluate(estimator, tokenizer):
    estimator.eval()

    def collate_function(batch_data):
        encoder_inputs, decoder_inputs = [], []
        for einputs, dinputs in batch_data:
            encoder_inputs.append(einputs)
            decoder_inputs.append(dinputs)

        # Encode the source texts into token ids
        encoder_inputs = tokenizer(encoder_inputs, max_length=1024, truncation=True, padding='longest', return_tensors="pt")
        # Move the tensors to the target device
        encoder_inputs = {key: value.to(device) for key, value in encoder_inputs.items()}
        return encoder_inputs, decoder_inputs

    # Validation dataset
    valid_inputs, valid_labels = data_read(type='valid')
    # valid_inputs, valid_labels = valid_inputs[:100], valid_labels[:100]
    valid_data = SummaryDataset(valid_inputs, valid_labels)
    dataloader = DataLoader(valid_data, shuffle=False, batch_size=2, collate_fn=collate_function)

    true_labels, pred_labels = [], []
    progress = tqdm(range(len(dataloader)))
    for encoder_inputs, decoder_inputs in dataloader:
        outputs = estimator.generate(encoder_inputs['input_ids'], max_length=512, num_beams=3, do_sample=False)
        pred_label = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        true_label = decoder_inputs
        true_labels.extend(true_label)
        pred_labels.extend(pred_label)
        progress.update()
    progress.close()

    true_labels = [' '.join(tokenizer.tokenize(true_label)) for true_label in true_labels]
    pred_labels = [' '.join(tokenizer.tokenize(pred_label)) for pred_label in pred_labels]

    # Compute ROUGE scores
    rouge = Rouge(metrics=['rouge-1', 'rouge-2', 'rouge-l'], stats=['f'])
    rouge_scores = rouge.get_scores(pred_labels, true_labels)

    rouge_1 = [score['rouge-1']['f'] for score in rouge_scores]
    rouge_2 = [score['rouge-2']['f'] for score in rouge_scores]
    rouge_L = [score['rouge-l']['f'] for score in rouge_scores]

    # Report the ROUGE scores as percentages, hence the factor of 100
    print('Rouge-1:', np.round(np.mean(rouge_1) * 100, decimals=2))
    print('Rouge-2:', np.round(np.mean(rouge_2) * 100, decimals=2))
    print('Rouge-L:', np.round(np.mean(rouge_L) * 100, decimals=2))

    # Compute the BLEU score
    # bleu_score expects the summaries as token lists rather than strings:
    # candidates: [['候', '选', '摘', '要', '1'], ['候', '选', '摘', '要', '2']]
    # references: [[['参考摘要1-1', '参考摘要1-2']], [['参考摘要2-1']]]
    # each candidate summary may be paired with multiple reference summaries
    true_labels = [[tokenizer.tokenize(true_label)] for true_label in true_labels]
    pred_labels = [tokenizer.tokenize(pred_label)   for pred_label in pred_labels]

    bleu = bleu_score(candidate_corpus=pred_labels, references_corpus=true_labels)
    print('BLEU-4:', np.round(bleu, decimals=4))


# Load a trained model from a checkpoint on disk
@torch.no_grad()
def evaluate(checkpoint):

    estimator = PegasusForConditionalGeneration.from_pretrained(checkpoint).to(device)
    estimator.eval()
    tokenizer = PegasusTokenizer.from_pretrained('Randeng-Pegasus-238M-Summary-Chinese')

    def collate_function(batch_data):
        encoder_inputs, decoder_inputs = [], []
        for einputs, dinputs in batch_data:
            encoder_inputs.append(einputs)
            decoder_inputs.append(dinputs)

        # Encode the source texts into token ids
        encoder_inputs = tokenizer(encoder_inputs, max_length=1024, truncation=True, padding='longest', return_tensors="pt")
        # Move the tensors to the target device
        encoder_inputs = {key: value.to(device) for key, value in encoder_inputs.items()}
        return encoder_inputs, decoder_inputs

    # Evaluation dataset (test split)
    valid_inputs, valid_labels = data_read(type='test')
    # valid_inputs, valid_labels = valid_inputs[:100], valid_labels[:100]
    valid_data = SummaryDataset(valid_inputs, valid_labels)
    dataloader = DataLoader(valid_data, shuffle=False, batch_size=4, collate_fn=collate_function)

    true_labels, pred_labels = [], []
    progress = tqdm(range(len(dataloader)))
    for encoder_inputs, decoder_inputs in dataloader:
        outputs = estimator.generate(encoder_inputs['input_ids'], max_length=512, num_beams=3, do_sample=False)
        pred_label = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        true_label = decoder_inputs
        true_labels.extend(true_label)
        pred_labels.extend(pred_label)
        progress.update()
    progress.close()

    true_labels = [' '.join(tokenizer.tokenize(true_label)) for true_label in true_labels]
    pred_labels = [' '.join(tokenizer.tokenize(pred_label)) for pred_label in pred_labels]

    # Compute ROUGE scores
    rouge = Rouge(metrics=['rouge-1', 'rouge-2', 'rouge-l'], stats=['f'])
    rouge_scores = rouge.get_scores(pred_labels, true_labels)

    rouge_1 = [score['rouge-1']['f'] for score in rouge_scores]
    rouge_2 = [score['rouge-2']['f'] for score in rouge_scores]
    rouge_L = [score['rouge-l']['f'] for score in rouge_scores]

    # Report the ROUGE scores as percentages, hence the factor of 100
    print('Rouge-1:', np.round(np.mean(rouge_1) * 100, decimals=2))
    print('Rouge-2:', np.round(np.mean(rouge_2) * 100, decimals=2))
    print('Rouge-L:', np.round(np.mean(rouge_L) * 100, decimals=2))

    # Compute the BLEU score
    # bleu_score expects the summaries as token lists rather than strings:
    # candidates: [['候', '选', '摘', '要', '1'], ['候', '选', '摘', '要', '2']]
    # references: [[['参考摘要1-1', '参考摘要1-2']], [['参考摘要2-1']]]
    # each candidate summary may be paired with multiple reference summaries
    true_labels = [[tokenizer.tokenize(true_label)] for true_label in true_labels]
    pred_labels = [tokenizer.tokenize(pred_label)   for pred_label in pred_labels]

    bleu = bleu_score(candidate_corpus=pred_labels, references_corpus=true_labels)
    print('BLEU-4:', np.round(bleu, decimals=4))

Evaluation results on the test split:

100%|███████████████████████████████████████| 2667/2667 [09:04<00:00,  4.90it/s]
Rouge-1: 30.42
Rouge-2: 15.8
Rouge-L: 28.28
BLEU-4: 0.1286
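
A side note on the two metric libraries: rouge expects whitespace-separated token strings, while torchtext's bleu_score expects token lists, with a list of references per candidate. A tiny self-contained example with made-up tokens:

from rouge import Rouge
from torchtext.data.metrics import bleu_score

pred = ['今天 天气 很 好']       # candidate summary as a space-joined token string
true = ['今天 天气 很 好 啊']    # reference summary as a space-joined token string
print(Rouge().get_scores(pred, true))

# bleu_score takes token lists; each candidate may have several references
print(bleu_score(candidate_corpus=[['今天', '天气', '很', '好']],
                 references_corpus=[[['今天', '天气', '很', '好', '啊']]]))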

4. Summary Generation

@torch.no_grad()
def predict():

    checkpoint = glob.glob('model/10-pegasus-summary-*')[0]
    print(checkpoint)

    estimator = PegasusForConditionalGeneration.from_pretrained(checkpoint)
    estimator.eval()

    # When loading vocab.txt, the modified PEGASUS tokenizer performs replacements such as:
    #         self.vocab[self.eos_token] = self.vocab.pop("[unused1]")
    #         # self.vocab[self.eos_token] = self.vocab.pop("[unused2]")
    #         self.vocab[self.pad_token] = self.vocab.pop("[PAD]")
    #         self.vocab[self.unk_token] = self.vocab.pop("[UNK]")
    # However, the vocab written by tokenizer.save_pretrained has already been rewritten, so
    # loading it again fails with a KeyError for "[unused1]".
    # Workaround: either reuse the vocab from the original model directory or comment out the
    # replacement code above. Here we simply overwrite the saved vocab.txt with the vocab.txt
    # from Randeng-Pegasus-238M-Summary-Chinese.
    tokenizer = PegasusTokenizer.from_pretrained('Randeng-Pegasus-238M-Summary-Chinese')

    text = '2007年乔布斯向人们展示iPhone并宣称“它将会改变世界”,还有人认为他在夸大其词,然而在8年后,以iPhone为代表的触屏智能手机已经席卷全球各个角落。未来,智能手机将会成为“真正的个人电脑”,为人类发展做出更大的贡献。'

    print(text)
    inputs = tokenizer([text], max_length=1024, truncation=True, return_attention_mask=False, return_tensors='pt')
    print(inputs)
    # num_beams=1 and do_sample=False -> greedy decoding
    # num_beams>1 and do_sample=False -> beam search decoding
    y_pred = estimator.generate(inputs['input_ids'], max_length=512, num_beams=3, do_sample=False)
    print(tokenizer.batch_decode(y_pred, skip_special_tokens=True))

Program output:

model/10-pegasus-summary-26907.70
2007年乔布斯向人们展示iPhone并宣称“它将会改变世界”,还有人认为他在夸大其词,然而在8年后,以iPhone为代表的触屏智能手机已经席卷全球各个角落。未来,智能手机将会成为“真正的个人电脑”,为人类发展做出更大的贡献。
{'input_ids': tensor([[43156,  1625, 36562,   810,  8311, 16475, 48063,  1626, 15849,   175,
          1397,  1463,   346, 20039,  7136,   176,  5661, 32012,   297, 30447,
           314,  1101,  1240,  1230,   527,  4628,  5661, 24631,  1101,   129,
          1625,   807,  5661,   323, 48063,   230,  8486,  3399,  4585,  1497,
          2305,  4054,  1934,  2355, 16764, 16900,  9715, 12721, 30342,   179,
         21415,  5661,  2305,  4054,  1934,  2355,  1463,   346, 18526,   175,
         26223,  3399, 43867,   176,  5661,   230,  8355, 12286,  9354,  2328,
          1230,  3399, 31045,   179,     1]])}
['“触屏”智能手机的未来之路']
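
If you want more varied summaries, you can also swap beam search for sampling when generating (a sketch; the top_k and top_p values are arbitrary):

y_pred = estimator.generate(inputs['input_ids'], max_length=512,
                            do_sample=True, top_k=50, top_p=0.9)
print(tokenizer.batch_decode(y_pred, skip_special_tokens=True))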
