Couplet generation is a text-generation task that typically requires the output to mirror the input in structure and to read rhythmically. In the past, this task was usually handled by fine-tuning a pretrained model. This time, we instead train a small Llama model from scratch, i.e., we use Llama the same way one would use a GRU or LSTM.
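The key difference from fine-tuning can be sketched in a few lines: we build a LlamaConfig with a small hidden size and layer count and instantiate LlamaForCausalLM with randomly initialized weights, rather than loading an existing checkpoint with from_pretrained. The sizes below are illustrative placeholders; the configuration actually used appears in section 2.

from transformers import LlamaConfig, LlamaForCausalLM

# A "mini" Llama built from scratch: no pretrained weights are loaded.
# The sizes here are placeholders for illustration; see section 2 for the real ones.
config = LlamaConfig(vocab_size=8000,
                     hidden_size=256,
                     num_hidden_layers=2,
                     num_attention_heads=2,
                     max_position_embeddings=128)
model = LlamaForCausalLM(config)  # randomly initialized, like constructing an nn.GRU / nn.LSTM
print(sum(p.numel() for p in model.parameters()))  # parameter count of the small model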
1. Building the vocabulary
import pickle


def demo01():
    # True if the text contains characters from the Unicode Private Use Area (corrupted couplets).
    is_unicode = lambda text: any('\ue000' <= char <= '\uf8ff' for char in text)
    hash_set = set()
    couplets = []

    def append(one, two):
        one, two = one.strip(), two.strip()
        if is_unicode(one) or is_unicode(two):
            return
        one_hash, two_hash = hash(one), hash(two)
        # Deduplicate: keep a pair only if neither line has been seen before.
        if one_hash not in hash_set and two_hash not in hash_set:
            one = ''.join(one.split())
            two = ''.join(two.split())
            couplets.append((one, two))
            hash_set.add(one_hash)
            hash_set.add(two_hash)

    fnames = (('couplet/train/in.txt', 'couplet/train/out.txt'),
              ('couplet/test/in.txt', 'couplet/test/out.txt'))
    for one_fname, two_fname in fnames:
        for one, two in zip(open(one_fname), open(two_fname)):
            append(one, two)

    # 对联数量: 563114
    print('对联数量:', len(couplets))
    # Hold out the last 5 couplets from the training set; they are used later as demo inputs.
    pickle.dump(couplets[:-5], open('data/couplet-train.pkl', 'wb'))

    # [('日里千人拱手划船,齐歌狂吼川江号子', '夜间百舸点灯敬佛,漫饮轻吟巴岳清音'),
    #  ('入迷途,吞苦果,回头是岸', '到此处,改前非,革面做人'),
    #  ('地近秦淮,看碧水蓝天,一行白鹭飞来何处', '门临闹市,入红楼翠馆,四海旅人宾至如归'),
    #  ('其巧在古倕以上', '所居介帝君之间'),
    #  ('万众齐心,已膺全国文明市', '千帆竞发,再鼓鹭江经济潮')]
    print(couplets[-5:])


def demo02():
    couplets = pickle.load(open('data/couplet-train.pkl', 'rb'))

    # Character-level vocabulary built from both the upper and lower lines.
    words = set()
    for one, two in couplets:
        words.update(list(one))
        words.update(list(two))

    # Sort the characters so the vocabulary order is reproducible across runs.
    special_tokens = ['[PAD]', '[BEG]', '[END]', '[UNK]']
    words = special_tokens + sorted(words)

    # 词表大小: 8794
    print('词表大小:', len(words))
    open('data/vocab.txt', 'w').writelines('\n'.join(words))


if __name__ == '__main__':
    demo01()
    demo02()
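As a quick sanity check (not part of the original scripts), the vocab.txt written by demo02 can be loaded into a BertTokenizer to confirm that couplets are split character by character and that [BEG] stays a single token. This is a minimal sketch assuming the file layout created above.

from transformers import BertTokenizer

# Hypothetical sanity check: load the vocabulary written by demo02.
tokenizer = BertTokenizer(vocab_file='data/vocab.txt', unk_token='[UNK]', pad_token='[PAD]')
tokenizer.add_special_tokens({'additional_special_tokens': ['[BEG]', '[END]']})

text = '入迷途,吞苦果,回头是岸[BEG]'
tokens = tokenizer.tokenize(text)
print(tokens)                                   # one token per character, plus a single [BEG] token
print(tokenizer.convert_tokens_to_ids(tokens))  # ids correspond to line numbers in data/vocab.txt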
2. Training the model
from transformers import BertTokenizer
from transformers import LlamaConfig
from transformers import LlamaForCausalLM
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
import torch
import pickle
import shutil
import os

torch.set_float32_matmul_precision('medium')


class CoupletGenerator(LightningModule):

    def __init__(self):
        super(CoupletGenerator, self).__init__()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = BertTokenizer(vocab_file='data/vocab.txt', unk_token='[UNK]', pad_token='[PAD]')
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['[BEG]', '[END]']})
        # A deliberately small Llama: 2 layers, 2 heads, 256-dim hidden states, randomly initialized.
        config = LlamaConfig(vocab_size=self.tokenizer.vocab_size,
                             hidden_size=256,
                             num_hidden_layers=2,
                             num_attention_heads=2,
                             pad_token_id=self.tokenizer.convert_tokens_to_ids('[PAD]'),
                             eos_token_id=self.tokenizer.convert_tokens_to_ids('[END]'),
                             max_position_embeddings=128)
        self.estimator = LlamaForCausalLM(config=config).to(device)

    def prepare_inputs(self, batch_datas):
        # Each training sample is the string "upper line [BEG] lower line [END]".
        batch_inputs = []
        for upper_line, lower_line in batch_datas:
            batch_inputs.append(upper_line + '[BEG]' + lower_line + '[END]')

        # Encode the batch
        batch_inputs = self.tokenizer.batch_encode_plus(batch_inputs,
                                                        add_special_tokens=False,
                                                        return_length=True,
                                                        return_tensors='pt',
                                                        padding=True,
                                                        return_token_type_ids=False)

        # Total number of tokens in the batch, used for the running average of the loss
        batch_size = torch.sum(batch_inputs['length'])
        batch_inputs.pop('length')

        # Labels are a copy of the inputs; padding positions are set to -100 so the loss ignores them
        labels = batch_inputs['input_ids'].clone()
        labels[labels == self.tokenizer.convert_tokens_to_ids('[PAD]')] = -100
        batch_inputs['labels'] = labels

        batch_inputs = {k: v.to(self.estimator.device) for k, v in batch_inputs.items()}
        return batch_inputs, batch_size

    def training_step(self, batch_datas, batch_index):
        batch_inputs, batch_size = self.prepare_inputs(batch_datas)
        model_result = self.estimator(**batch_inputs)

        # Accumulate the total loss
        self.total_loss += model_result.loss.item() * batch_size
        self.total_size += batch_size
        self.log('Loss', self.total_loss / self.total_size, prog_bar=True)

        return model_result.loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.estimator.parameters(), lr=1e-3)
        return optimizer

    def save_pretrained(self):
        fname = f'checkpoints/couplet-{self.current_epoch}'
        if os.path.exists(fname):
            shutil.rmtree(fname)
        os.makedirs(fname)
        self.estimator.save_pretrained(fname)
        self.tokenizer.save_pretrained(fname)

    @classmethod
    def from_pretrained(cls, checkpoint):
        generator = cls()
        generator.estimator = LlamaForCausalLM.from_pretrained(checkpoint)
        generator.tokenizer = BertTokenizer.from_pretrained(checkpoint)
        return generator

    def generate(self, upper_lines):
        batch_inputs = [upper_line + '[BEG]' for upper_line in upper_lines]
        batch_inputs = self.tokenizer.batch_encode_plus(batch_inputs,
                                                        add_special_tokens=False,
                                                        padding=True,
                                                        return_tensors='pt',
                                                        return_length=True,
                                                        return_token_type_ids=False)
        # Move tensors to the model device
        batch_inputs = {k: v.to(self.estimator.device) for k, v in batch_inputs.items()}
        upper_lens = batch_inputs.pop('length')

        with torch.no_grad():
            lower_lines = self.estimator.generate(**batch_inputs,
                                                  max_length=128,
                                                  eos_token_id=self.tokenizer.convert_tokens_to_ids('[END]'),
                                                  pad_token_id=self.tokenizer.convert_tokens_to_ids('[PAD]'))

        # Decode the generated lower lines (drop the prompt tokens first)
        for upper_line, lower_line, upper_len in zip(upper_lines, lower_lines, upper_lens):
            lower_line = lower_line[upper_len:]
            lower_line = self.tokenizer.decode(lower_line, skip_special_tokens=True)
            print(f"上联: {upper_line}\n下联: {''.join(lower_line.split())}\n")
        print('-' * 100)

    def on_train_epoch_start(self):
        self.total_loss = 0
        self.total_size = 0

    def on_train_epoch_end(self):
        # After each epoch, generate lower lines for the 5 held-out couplets.
        couplets = [('日里千人拱手划船,齐歌狂吼川江号子', '夜间百舸点灯敬佛,漫饮轻吟巴岳清音'),
                    ('入迷途,吞苦果,回头是岸', '到此处,改前非,革面做人'),
                    ('地近秦淮,看碧水蓝天,一行白鹭飞来何处', '门临闹市,入红楼翠馆,四海旅人宾至如归'),
                    ('其巧在古倕以上', '所居介帝君之间'),
                    ('万众齐心,已膺全国文明市', '千帆竞发,再鼓鹭江经济潮')]
        for upper_line, _ in couplets:
            self.generate([upper_line])
        # Save the model
        self.save_pretrained()


def train():
    train_data = pickle.load(open('data/couplet-train.pkl', 'rb'))
    dataloader = DataLoader(train_data,
                            batch_size=128,
                            num_workers=6,
                            persistent_workers=True,
                            shuffle=True,
                            collate_fn=lambda batch_data: batch_data)
    estimator = CoupletGenerator()
    trainer = Trainer(max_epochs=10, logger=False, enable_checkpointing=False)
    trainer.fit(estimator, dataloader)


if __name__ == '__main__':
    train()
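Once training finishes, each epoch's weights and tokenizer files sit under checkpoints/couplet-<epoch>. A minimal inference sketch, assuming the class above is importable and the final epoch's directory exists (couplet-9 here, since max_epochs=10):

# Minimal inference sketch; 'checkpoints/couplet-9' stands in for whatever
# directory save_pretrained produced last on your run.
generator = CoupletGenerator.from_pretrained('checkpoints/couplet-9')
generator.estimator.eval()
generator.generate(['其巧在古倕以上'])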