Couplet generation is a text-generation task in which the output must mirror the input in symmetry and rhythm. In the past we have mostly approached it by fine-tuning pre-trained models. Here we instead train a small-sized Llama model from scratch, i.e. we use Llama the same way we would use a GRU or LSTM.
1. Building the Vocabulary
import pickle
import re
from tqdm import tqdm


def demo01():
    # Drop lines that contain private-use-area characters (unrenderable glyphs)
    is_unicode = lambda text: any('\ue000' <= char <= '\uf8ff' for char in text)
    hash_set = set()
    couplets = []

    def append(one, two):
        one, two = one.strip(), two.strip()
        if is_unicode(one) or is_unicode(two):
            return
        # Deduplicate by hashing both the upper and the lower line
        one_hash, two_hash = hash(one), hash(two)
        if one_hash not in hash_set and two_hash not in hash_set:
            one = ''.join(one.split())
            two = ''.join(two.split())
            couplets.append((one, two))
            hash_set.add(one_hash)
            hash_set.add(two_hash)

    fnames = (('couplet/train/in.txt', 'couplet/train/out.txt'),
              ('couplet/test/in.txt', 'couplet/test/out.txt'))
    for one_fname, two_fname in fnames:
        for one, two in zip(open(one_fname), open(two_fname)):
            append(one, two)

    # 对联数量: 563114
    print('对联数量:', len(couplets))
    # Hold out the last 5 couplets; they are reused as generation demos during training
    pickle.dump(couplets[:-5], open('data/couplet-train.pkl', 'wb'))
    # [('日里千人拱手划船,齐歌狂吼川江号子', '夜间百舸点灯敬佛,漫饮轻吟巴岳清音'),
    #  ('入迷途,吞苦果,回头是岸', '到此处,改前非,革面做人'),
    #  ('地近秦淮,看碧水蓝天,一行白鹭飞来何处', '门临闹市,入红楼翠馆,四海旅人宾至如归'),
    #  ('其巧在古倕以上', '所居介帝君之间'),
    #  ('万众齐心,已膺全国文明市', '千帆竞发,再鼓鹭江经济潮')]
    print(couplets[-5:])


def demo02():
    couplets = pickle.load(open('data/couplet-train.pkl', 'rb'))
    words = set()
    for one, two in couplets:
        words.update(list(one))
        words.update(list(two))
    # Put the special tokens at the front of the vocabulary
    special_tokens = ['[PAD]', '[BEG]', '[END]', '[UNK]']
    words = special_tokens + list(words)
    # 词表大小: 8794
    print('词表大小:', len(words))
    open('data/vocab.txt', 'w').write('\n'.join(words))


if __name__ == '__main__':
    demo01()
    demo02()
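
Before moving on to training, it can help to sanity-check the vocabulary file with the same tokenizer setup that the training script uses. The snippet below is an optional check, not part of the original pipeline; it assumes data/vocab.txt was produced by demo02 above.

from transformers import BertTokenizer


def check_vocab():
    tokenizer = BertTokenizer(vocab_file='data/vocab.txt', unk_token='[UNK]', pad_token='[PAD]')
    tokenizer.add_special_tokens({'additional_special_tokens': ['[BEG]', '[END]']})
    # Should match the size reported by demo02 (8794)
    print('vocab size:', tokenizer.vocab_size)
    # Chinese text is tokenized character by character; [BEG] stays a single token
    print(tokenizer.tokenize('入迷途,吞苦果,回头是岸[BEG]'))
    # The four special tokens were written first, so they should map to ids 0-3
    print(tokenizer.convert_tokens_to_ids(['[PAD]', '[BEG]', '[END]', '[UNK]']))


if __name__ == '__main__':
    check_vocab()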
2. Training the Model
from transformers import BertTokenizer
from transformers import LlamaConfig
from transformers import LlamaForCausalLM
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
torch.set_float32_matmul_precision('medium')
import pickle
import shutil
import os


class CoupletGenerator(LightningModule):

    def __init__(self):
        super(CoupletGenerator, self).__init__()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = BertTokenizer(vocab_file='data/vocab.txt', unk_token='[UNK]', pad_token='[PAD]')
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['[BEG]', '[END]']})
        # A deliberately tiny Llama configuration, trained from scratch
        config = LlamaConfig(vocab_size=self.tokenizer.vocab_size,
                             hidden_size=256,
                             num_hidden_layers=2,
                             num_attention_heads=2,
                             pad_token_id=self.tokenizer.convert_tokens_to_ids('[PAD]'),
                             eos_token_id=self.tokenizer.convert_tokens_to_ids('[END]'),
                             max_position_embeddings=128)
        self.estimator = LlamaForCausalLM(config=config).to(device)

    def prepare_inputs(self, batch_datas):
        batch_inputs = []
        for upper_line, lower_line in batch_datas:
            # Concatenate upper and lower lines into a single causal-LM sequence
            batch_inputs.append(upper_line + '[BEG]' + lower_line + '[END]')
        # Encode the batch
        batch_inputs = self.tokenizer.batch_encode_plus(batch_inputs,
                                                        add_special_tokens=False,
                                                        return_length=True,
                                                        return_tensors='pt',
                                                        padding=True,
                                                        return_token_type_ids=False)
        # Total token count, used to weight the running average loss
        token_count = torch.sum(batch_inputs['length'])
        batch_inputs.pop('length')
        # Labels are the inputs themselves; mask padding positions with -100 so the loss ignores them
        batch_inputs['labels'] = batch_inputs['input_ids'].clone()
        batch_inputs['labels'][batch_inputs['labels'] == self.tokenizer.convert_tokens_to_ids('[PAD]')] = -100
        batch_inputs = {k: v.to(self.estimator.device) for k, v in batch_inputs.items()}
        return batch_inputs, token_count

    def training_step(self, batch_datas, batch_index):
        batch_inputs, token_count = self.prepare_inputs(batch_datas)
        model_result = self.estimator(**batch_inputs)
        # Accumulate the running loss
        self.total_loss += model_result.loss.item() * token_count
        self.total_size += token_count
        self.log('Loss', self.total_loss / self.total_size, prog_bar=True)
        return model_result.loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.estimator.parameters(), lr=1e-3)
        return optimizer

    def save_pretrained(self):
        fname = f'checkpoints/couplet-{self.current_epoch}'
        if os.path.exists(fname):
            shutil.rmtree(fname)
        os.makedirs(fname)
        self.estimator.save_pretrained(fname)
        self.tokenizer.save_pretrained(fname)

    @classmethod
    def from_pretrained(cls, checkpoint):
        generator = cls()
        # from_pretrained returns new objects, so assign them back explicitly
        device = generator.estimator.device
        generator.estimator = LlamaForCausalLM.from_pretrained(checkpoint).to(device)
        generator.tokenizer = BertTokenizer.from_pretrained(checkpoint)
        return generator

    def generate(self, upper_lines):
        # Prompt the model with the upper line followed by [BEG]
        prompts = [upper_line + '[BEG]' for upper_line in upper_lines]
        batch_inputs = self.tokenizer.batch_encode_plus(prompts,
                                                        add_special_tokens=False,
                                                        padding=True,
                                                        return_tensors='pt',
                                                        return_length=True,
                                                        return_token_type_ids=False)
        # Move tensors to the model device
        batch_inputs = {k: v.to(self.estimator.device) for k, v in batch_inputs.items()}
        upper_lens = batch_inputs['length']
        batch_inputs.pop('length')
        with torch.no_grad():
            lower_lines = self.estimator.generate(**batch_inputs,
                                                  max_length=128,
                                                  eos_token_id=self.tokenizer.convert_tokens_to_ids('[END]'),
                                                  pad_token_id=self.tokenizer.convert_tokens_to_ids('[PAD]'))
        # Decode the generated lower lines (drop the prompt tokens and special tokens)
        for upper_line, lower_line, upper_len in zip(upper_lines, lower_lines, upper_lens):
            lower_line = lower_line[upper_len:]
            lower_line = self.tokenizer.decode(lower_line, skip_special_tokens=True)
            print(f"上联: {upper_line}\n下联: {''.join(lower_line.split())}\n")
        print('-' * 100)

    def on_train_epoch_start(self):
        self.total_loss = 0
        self.total_size = 0

    def on_train_epoch_end(self):
        # The 5 held-out couplets from demo01, used as a qualitative check after each epoch
        couplets = [('日里千人拱手划船,齐歌狂吼川江号子', '夜间百舸点灯敬佛,漫饮轻吟巴岳清音'),
                    ('入迷途,吞苦果,回头是岸', '到此处,改前非,革面做人'),
                    ('地近秦淮,看碧水蓝天,一行白鹭飞来何处', '门临闹市,入红楼翠馆,四海旅人宾至如归'),
                    ('其巧在古倕以上', '所居介帝君之间'),
                    ('万众齐心,已膺全国文明市', '千帆竞发,再鼓鹭江经济潮')]
        for upper_line, _ in couplets:
            self.generate([upper_line])
        # Save a checkpoint after every epoch
        self.save_pretrained()


def train():
    train_data = pickle.load(open('data/couplet-train.pkl', 'rb'))
    dataloader = DataLoader(train_data,
                            batch_size=128,
                            num_workers=6,
                            persistent_workers=True,
                            shuffle=True,
                            collate_fn=lambda batch_data: batch_data)
    estimator = CoupletGenerator()
    trainer = Trainer(max_epochs=10, logger=False, enable_checkpointing=False)
    trainer.fit(estimator, dataloader)


if __name__ == '__main__':
    train()
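
Once training has produced a checkpoint, the model can be reloaded for standalone inference. The sketch below reuses the CoupletGenerator class defined above; the checkpoint path checkpoints/couplet-9 is only an example of a directory written by save_pretrained (one is saved per epoch), and predict() would be run in place of train() once such a checkpoint exists.

def predict():
    # Load the estimator and tokenizer saved by save_pretrained (path is illustrative)
    generator = CoupletGenerator.from_pretrained('checkpoints/couplet-9')
    generator.estimator.eval()
    # Generate lower lines for a couple of the held-out upper lines
    generator.generate(['万众齐心,已膺全国文明市'])
    generator.generate(['入迷途,吞苦果,回头是岸'])


if __name__ == '__main__':
    predict()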
