GPT-2(Generative Pre-trained Transformer 2)是 OpenAI 开发的一种基于 Transformer 结构的自回归语言模型。它以无监督学习的方式在大规模文本数据上进行训练,能够生成连贯、富有逻辑性的文本,并广泛应用于文本生成、续写、翻译和摘要等任务。接下来我们利用 GPT-2 中文预训练模型来进行文本摘要任务。 OpenAI 发布的预训练模型是基于英文语料训练的,我找到了一个中文的预训练模型:
https://huggingface.co/IDEA-CCNL/Wenzhong2.0-GPT2-110M-BertTokenizer-chinese
1. 数据处理
LCSTS(Large-scale Chinese Short Text Summarization Dataset)是一个大规模的中文短文本摘要数据集,由哈尔滨工业大学(HIT)于 2015 年发布,主要用于研究 中文文本摘要(Text Summarization) 任务。
下面我们就使用该数据集进行文本摘要模型的训练和评估。首先,我们需要对该数据集进行处理,以便于适合模型的训练和评估。下面是数据示例:
原文 | 摘要 |
一辆小轿车,一名女司机,竟造成9死24伤。日前,深圳市交警局对事故进行通报:从目前证据看,事故系司机超速行驶且操作不当导致。目前24名伤员已有6名治愈出院,其余正接受治疗,预计事故赔偿费或超一千万元。 | 深圳机场9死24伤续:司机全责赔偿或超千万 |
训练数据使用分词器预先构建好,适合模型训练,格式:
原文[SEP]摘要[END] [108, 2315, 6328, 879, 56, 76 ...]
验证集和测试集数据格式:
[(原文,摘要), (原文,摘要) ...]
import pickle from transformers import BertTokenizer from tqdm import tqdm import joblib import glob import os tokenizer = BertTokenizer.from_pretrained('Wenzhong2.0-GPT2-110M-BertTokenizer-chinese') tokenizer.add_special_tokens({'eos_token': '[EOS]'}) def get_train_copurs(): srcs = open('LCSTS/train.src.txt').readlines() tgts = open('LCSTS/train.tgt.txt').readlines() # 选择部分数据 # select_number = 100000 # srcs = srcs[:select_number] # tgts = tgts[:select_number] sample_size, block_count = len(srcs), 10 block_size = sample_size // block_count print('总数据量:', sample_size, '线程数量:', block_count, '线程数据量:', block_size) @joblib.delayed def task(s, e, i): local_srcs, local_tgts = srcs[s: e], tgts[s: e] samples = [] progress = tqdm(range(len(local_srcs)), desc=f'train 数据处理') for src, tgt in zip(local_srcs, local_tgts): progress.update() src, tgt = src.strip(), tgt.strip() sample = src + '[SEP]' + tgt + '[EOS]' sample = tokenizer.tokenize(sample) if sample.count('[UNK]') > 0: continue sample = tokenizer.encode(sample, add_special_tokens=False) samples.append(sample) progress.close() pickle.dump(samples, open(f'data/train-{i}.pkl', 'wb')) tasks = [] for i in range(block_count): s = i * block_size e = (i + 1) * block_size tasks.append(task(s, e, i)) joblib.Parallel(n_jobs=block_count)(tasks) def get_other_corpus(data_type='test'): srcs = open(f'LCSTS/{data_type}.src.txt').readlines() tgts = open(f'LCSTS/{data_type}.tgt.txt').readlines() samples = [] for src, tgt in zip(srcs, tgts): src, tgt = src.strip(), tgt.strip() sample = tokenizer.tokenize(src + tgt) if sample.count('[UNK]') > 0: continue samples.append((src, tgt)) pickle.dump(samples, open(f'data/summary-{data_type}.pkl', 'wb')) print(f'{data_type} {len(samples)} 数据处理完毕!') print(f'{data_type}:', samples[0]) def demo01(): get_train_copurs() get_other_corpus('test') get_other_corpus('valid') def demo02(): fnames = glob.glob('data/train-*.pkl') print(fnames) samples = [] for fname in fnames: batch_sample = pickle.load(open(fname, 'rb')) samples.extend(batch_sample) # os.remove(fname) save_path = f'data/summary-train.pkl' pickle.dump(samples, open(save_path, 'wb')) print(f'train {len(samples)} 数据处理完毕!') print(f'train:', samples[0]) if __name__ == '__main__': # demo01() demo02()
demo01 函数 总数据量: 2400591 线程数量: 10 线程数据量: 240059 train 数据处理: 100%|█████████████████| 240059/240059 [01:39<00:00, 2408.24it/s] train 数据处理: 100%|█████████████████| 240059/240059 [01:41<00:00, 2365.45it/s] train 数据处理: 100%|█████████████████| 240059/240059 [01:46<00:00, 2244.64it/s] train 数据处理: 100%|█████████████████| 240059/240059 [01:42<00:00, 2334.02it/s] train 数据处理: 100%|█████████████████| 240059/240059 [01:38<00:00, 2444.17it/s] train 数据处理: 100%|█████████████████| 240059/240059 [01:42<00:00, 2342.27it/s] train 数据处理: 100%|█████████████████| 240059/240059 [01:57<00:00, 2046.21it/s] train 数据处理: 100%|█████████████████| 240059/240059 [02:14<00:00, 1784.23it/s] train 数据处理: 100%|█████████████████| 240059/240059 [02:16<00:00, 1763.48it/s] train 数据处理: 100%|█████████████████| 240059/240059 [02:20<00:00, 1707.82it/s] test 10090 数据处理完毕! test: ('日前,方舟子发文直指林志颖旗下爱碧丽推销假保健品,引起哗然。调查发现,爱碧丽没有自己的生产加工厂。其胶原蛋白饮品无核心研发,全部代工生产。号称有“逆生长”功效的爱碧丽“梦幻奇迹限量组”售价高达1080元,实际成本仅为每瓶4元!', '林志颖公司疑涉虚假营销无厂房无研发') valid 1062 数据处理完毕! valid: ('本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代人', '可穿戴技术十大设计原则') demo02 函数 ['data/train-4.pkl', 'data/train-5.pkl', 'data/train-3.pkl', 'data/train-7.pkl', 'data/train-0.pkl', 'data/train-8.pkl', 'data/train-2.pkl', 'data/train-6.pkl', 'data/train-9.pkl', 'data/train-1.pkl'] train 2297021 数据处理完毕! train: [1266, 776, 2356, 6823, 3817, 1921, 1765, 4852, 1277, 8024, 1914, 3406, 3517, 2791, 7674, 2231, 1469, 7553, 2231, 4900, 3022, 744, 2456, 698, 7028, 8024, 3300, 689, 712, 1762, 3517, 7553, 7463, 1378, 1217, 4667, 3517, 2231, 8024, 1146, 2768, 7392, 7313, 2190, 1912, 1139, 4909, 511, 6824, 2456, 2347, 698, 7028, 2512, 1510, 749, 1071, 800, 857, 2787, 4638, 3633, 2382, 4495, 3833, 511, 8213, 2399, 2207, 1277, 4289, 689, 3295, 1403, 3791, 7368, 6629, 6401, 6824, 2456, 689, 712, 8024, 1400, 1352, 3175, 6809, 2768, 1469, 6237, 1291, 6379, 8024, 852, 2400, 3313, 2533, 1168, 3300, 3126, 2809, 6121, 138, 9463, 140, 1266, 776, 6823, 3817, 1921, 1765, 2100, 1762, 1920, 7030, 3517, 7553, 3022, 2456, 3791, 7368, 1161, 749, 2809, 6121, 7410, 21135]
2. 模型训练
训练代码使用 transformers 的 Trainer 来完成,训练信息:
- learning_rate=5e-5
- per_device_train_batch_size=24
- gradient_accumulation_steps=8
- num_train_epochs=5
我们选择 30 万样本训练了 7 个小时。
import pickle import torch from transformers import GPT2LMHeadModel from transformers import BertTokenizer from transformers import Trainer from transformers import TrainingArguments from transformers import DataCollatorForLanguageModeling from transformers import TrainerCallback from transformers import TrainerState, TrainerControl device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') class EvalCallBack(TrainerCallback): def __init__(self, estimator, tokenizer): self.estimator = estimator self.tokenizer = tokenizer def on_epoch_end(self, args, state, control, **kwargs): src = '周其凤2012年在长沙一中发表演讲时批评美国教育失败引起争议。再谈及此次风波,周其凤说:“我们现在有一部分国民很可悲,可以骂自己的娘,但却不可以骂美国,我其实也不是骂,我也说了很多美国教育的好。' inputs = self.tokenizer.encode(src + '[SEP]', max_length=512, truncation=True, add_special_tokens=False) input_length = len(inputs) inputs = {'input_ids': torch.tensor([inputs]).to(device)} output = self.estimator.generate(**inputs, do_sample=False, max_length=512, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id) output = self.tokenizer.decode(output.squeeze()[input_length:], skip_special_tokens=False, top_p=0.95) print() print('生成文本:', ''.join(output.split())) # watch -n 1 nvidia-smi def train(): modelpath = 'Wenzhong2.0-GPT2-110M-BertTokenizer-chinese' estimator = GPT2LMHeadModel.from_pretrained(modelpath).to(device) tokenizer = BertTokenizer.from_pretrained(modelpath) special_tokens = {'eos_token': '[EOS]'} tokenizer.add_special_tokens(special_tokens) # 增加新的 token 之后,扩充 embedding,否则无法识别 estimator.resize_token_embeddings(len(tokenizer)) train_args = TrainingArguments( output_dir='summary', do_eval=False, save_strategy='epoch', save_total_limit=5, logging_strategy='no', learning_rate=5e-5, optim='adamw_torch', per_device_train_batch_size=24, gradient_accumulation_steps=8, num_train_epochs=5, disable_tqdm=False, ) train_data = pickle.load(open('data/summary-train.pkl', 'rb'))[:300000] trainer = Trainer(model=estimator, args=train_args, train_dataset=train_data, data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)) trainer.add_callback(EvalCallBack(estimator=estimator, tokenizer=tokenizer)) trainer.train() if __name__ == '__main__': train()
20%|███████ | 1562/7810 [1:24:18<5:35:45, 3.22s/it] 生成文本: 周其凤:我不是骂娘,我也说了很多美国教育的好[EOS] 40%|██████████████ | 3125/7810 [2:48:41<4:13:27, 3.25s/it] 生成文本: 周其凤:我们有一部分国民很可悲[EOS] 60%|█████████████████████ | 4687/7810 [4:13:01<2:48:09, 3.23s/it] 生成文本: 周其凤:我不是骂美国,我也说了很多美国教育的好[EOS] 80%|████████████████████████████ | 6250/7810 [5:37:25<1:24:12, 3.24s/it] 生成文本: 周其凤:我不是骂,我也说了很多美国教育的好[EOS] 100%|█████████████████████████████████████| 7810/7810 [7:01:39<00:00, 3.23s/it] 生成文本: 周其凤:我们现在有一部分国民很可悲[EOS] {'train_runtime': 25302.1607, 'train_samples_per_second': 59.283, 'train_steps_per_second': 0.309, 'train_loss': 2.5490039312580026, 'epoch': 5.0} 100%|█████████████████████████████████████| 7810/7810 [7:01:42<00:00, 3.24s/it]
3. 模型推理
随机从验证集中选择一条原文,由模型生成摘要,人工对比下生成质量。
import random from transformers import GPT2LMHeadModel from transformers import BertTokenizer import torch import pickle def inference(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') estimator = GPT2LMHeadModel.from_pretrained('summary/checkpoint-6250').to(device) tokenizer = BertTokenizer.from_pretrained('Wenzhong2.0-GPT2-110M-BertTokenizer-chinese') special_tokens = {'eos_token': '[EOS]'} tokenizer.add_special_tokens(special_tokens) valid_data = pickle.load(open('data/summary-valid.pkl', 'rb')) while True: message = input('请输入:') if message == 'exit': break index = random.randint(0, len(valid_data)) src, tgt = valid_data[index] print('输入文本:', src) print('参考文本:', tgt) inputs = tokenizer.encode(src + '[SEP]', max_length=512, truncation=True, add_special_tokens=False) input_length = len(inputs) inputs = {'input_ids': torch.tensor([inputs]).to(device)} output = estimator.generate(**inputs, do_sample=False, max_length=512, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id) output = output.squeeze()[input_length:] output = tokenizer.decode(output, skip_special_tokens=False, do_sample=False, top_p=0.95) print('生成文本:', ''.join(output.split())) if __name__ == '__main__': inference()
输入文本: 博鳌亚洲论坛2014年年会11日落下帷幕。中国人民银行行长周小川说,央行历来强调货币政策调控主要包括通货膨胀率、经济增长率、新增就业、国际收支平衡状况四个指标,根据当前情况,央行可以把促进就业作为重要参考目标,但控制通货膨胀更为重要。 参考文本: 周小川:货币政策侧重控制通胀 生成文本: 周小川:货币政策调控主要包括通货膨胀率[EOS] 输入文本: 昨晚,中联航空成都飞北京一架航班被发现有多人吸烟。后因天气原因,飞机备降太原机场。几名乘客在舱门边吸烟被发现。有乘客要求重新安检,机长决定继续飞行,引起机组人员与未吸烟乘客冲突。目前中联航空正联系机组进行核实。 参考文本: 成都飞北京航班多人吸烟机组人员与未吸烟乘客冲突 生成文本: 中联航空成都飞北京航班多人吸烟[EOS] 输入文本: “山西官场被查蒙了。”8月29日晚,山西省政府一位官员对记者说。当天,毫无征兆的情况下,中纪委先后宣布山西省委常委、统战部长白云、山西省副省长任润厚涉嫌严重违纪违法被查。从年初至今,山西已有7名省部级官员被查。 参考文本: 山西省政府官员:山西官场被查蒙了 生成文本: 山西7名省部级官员被查[EOS] 输入文本: 据悉,国家大基金落地后,各地政府将陆续跟进。目前北京市集成电路促进基金已就位,规模约300亿。上海市基金也将落地,规模或有500亿。天津、厦门、合肥、成都、西安等地方政府也在筹备基金政策,各地方政府资金规模都在数十亿间。(21CBH) 参考文本: 各地跟进国家集成电路大基金 生成文本: 国家大基金落地后将陆续跟进[EOS] 输入文本: 中国车市自2011年开始告别一路狂飙,进入平稳增长阶段。今年一季度,国内车市表现疲软,出现多年未见的同比环比双下降。但步入稳步增长阶段的中国车市凭借巨大商机和未来前景依然吸引了众多跨国汽车公司和国内汽车企业踊跃参展 参考文本: 一季度车市疲软北京国际车展火爆依然 生成文本: 中国车市:一季度车市表现疲软[EOS]