Training Llama from Scratch to Generate Couplets

Couplet generation is a text-generation task that typically requires the output to be structurally symmetric and rhythmically matched to the input. In the past, we have mostly approached it by fine-tuning a pretrained model. This time, we try a small Llama model trained entirely from scratch, that is, we treat Llama the same way we would a GRU or LSTM.
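
The key point is that no pretrained weights are involved: constructing LlamaForCausalLM from a LlamaConfig yields a randomly initialized network, just as instantiating an nn.GRU or nn.LSTM does. A minimal illustrative sketch of that idea (the values mirror the configuration used later in section 2):

from transformers import LlamaConfig, LlamaForCausalLM

# A tiny, randomly initialized Llama: nothing is downloaded, no pretrained checkpoint is loaded
config = LlamaConfig(vocab_size=8794, hidden_size=256,
                     num_hidden_layers=2, num_attention_heads=2,
                     max_position_embeddings=128)
model = LlamaForCausalLM(config)
print('parameters:', sum(p.numel() for p in model.parameters()))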

1. Building the Vocabulary

import pickle
import re
from tqdm import tqdm


def demo01():
    # Filter out lines containing Private Use Area characters (garbled glyphs in the raw data)
    is_unicode = lambda text: any('\ue000' <= char <= '\uf8ff' for char in text)

    hash_set = set()
    couplets = []
    # Deduplicate by hashing each line; only keep pairs where both lines are unseen
    def append(one, two):
        one, two = one.strip(), two.strip()
        if is_unicode(one) or is_unicode(two):
            return
        one_hash, two_hash = hash(one), hash(two)
        if one_hash not in hash_set and two_hash not in hash_set:
            one = ''.join(one.split())
            two = ''.join(two.split())
            couplets.append((one, two))
            hash_set.add(one_hash)
            hash_set.add(two_hash)

    fnames = (('couplet/train/in.txt', 'couplet/train/out.txt'), ('couplet/test/in.txt', 'couplet/test/out.txt'))
    for one_fname, two_fname in fnames:
        for one, two in zip(open(one_fname), open(two_fname)):
            append(one, two)

    # Number of couplets: 563114
    print('Number of couplets:', len(couplets))

    # Hold out the last 5 couplets for qualitative checks during training
    pickle.dump(couplets[:-5], open('data/couplet-train.pkl', 'wb'))
    # [('日里千人拱手划船,齐歌狂吼川江号子', '夜间百舸点灯敬佛,漫饮轻吟巴岳清音'),
    # ('入迷途,吞苦果,回头是岸', '到此处,改前非,革面做人'),
    # ('地近秦淮,看碧水蓝天,一行白鹭飞来何处', '门临闹市,入红楼翠馆,四海旅人宾至如归'),
    # ('其巧在古倕以上', '所居介帝君之间'),
    # ('万众齐心,已膺全国文明市', '千帆竞发,再鼓鹭江经济潮')]
    print(couplets[-5:])


def demo02():
    couplets = pickle.load(open('data/couplet-train.pkl', 'rb'))

    words = set()
    for one, two in couplets:
        words.update(list(one))
        words.update(list(two))

    special_tokens = ['[PAD]', '[BEG]', '[END]', '[UNK]']
    words = special_tokens + list(words)
    # Vocabulary size: 8794
    print('Vocabulary size:', len(words))
    open('data/vocab.txt', 'w', encoding='utf-8').write('\n'.join(words))


if __name__ == '__main__':
    demo01()
    demo02()
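
Before training, it is worth checking that BertTokenizer loads data/vocab.txt the way the training code assumes: one token per Chinese character, with [BEG] and [END] kept intact as single special tokens. A small sanity-check sketch along those lines (the example line is one of the held-out couplets above):

from transformers import BertTokenizer

tokenizer = BertTokenizer(vocab_file='data/vocab.txt', unk_token='[UNK]', pad_token='[PAD]')
tokenizer.add_special_tokens({'additional_special_tokens': ['[BEG]', '[END]']})

text = '入迷途,吞苦果,回头是岸[BEG]'
tokens = tokenizer.tokenize(text)
print(tokens)                                    # one token per character, plus [BEG]
print(tokenizer.convert_tokens_to_ids(tokens))   # ids used as model inputs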

2. Model Training

from transformers import BertTokenizer
from transformers import LlamaConfig
from transformers import LlamaForCausalLM
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
torch.set_float32_matmul_precision('medium')

import pickle
import shutil
import os


class CoupletGenerator(LightningModule):

    def __init__(self):
        super(CoupletGenerator, self).__init__()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # BertTokenizer is only used as a simple character-level tokenizer over our own vocab
        self.tokenizer = BertTokenizer(vocab_file='data/vocab.txt', unk_token='[UNK]', pad_token='[PAD]')
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['[BEG]', '[END]']})

        config = LlamaConfig(vocab_size=self.tokenizer.vocab_size,
                             hidden_size=256,
                             num_hidden_layers=2,
                             num_attention_heads=2,
                             pad_token_id=self.tokenizer.convert_tokens_to_ids('[PAD]'),
                             eos_token_id=self.tokenizer.convert_tokens_to_ids('[END]'),
                             max_position_embeddings=128)
        # Randomly initialized small Llama: no pretrained weights are loaded
        self.estimator = LlamaForCausalLM(config=config).to(device)


    def prepare_inputs(self, batch_datas):
        # Concatenate each pair as: upper[BEG]lower[END]
        batch_inputs = [upper_line + '[BEG]' + lower_line + '[END]' for upper_line, lower_line in batch_datas]
        # Encode the inputs
        batch_inputs = self.tokenizer.batch_encode_plus(batch_inputs,
                                                        add_special_tokens=False,
                                                        return_length=True,
                                                        return_tensors='pt',
                                                        padding=True,
                                                        return_token_type_ids=False)
        # Total number of real (unpadded) tokens, used to weight the running loss
        batch_size = torch.sum(batch_inputs['length'])
        batch_inputs.pop('length')
        # Labels are the inputs themselves; mask [PAD] positions so the loss ignores them
        batch_inputs['labels'] = batch_inputs['input_ids'].clone()
        pad_token_id = self.tokenizer.convert_tokens_to_ids('[PAD]')
        batch_inputs['labels'][batch_inputs['labels'] == pad_token_id] = -100
        batch_inputs = { k: v.to(self.estimator.device) for k, v in batch_inputs.items() }
        return batch_inputs, batch_size


    def training_step(self, batch_datas, batch_index):
        batch_inputs, batch_size = self.prepare_inputs(batch_datas)
        model_result = self.estimator(**batch_inputs)

        # Track the running average loss over the epoch
        self.total_loss += model_result.loss.item() * batch_size
        self.total_size += batch_size
        self.log('Loss', self.total_loss / self.total_size, prog_bar=True)
        return model_result.loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.estimator.parameters(), lr=1e-3)
        return optimizer

    def save_pretrained(self):
        fname = f'checkpoints/couplet-{self.current_epoch}'
        if os.path.exists(fname):
            shutil.rmtree(fname)
        os.makedirs(fname)
        self.estimator.save_pretrained(fname)
        self.tokenizer.save_pretrained(fname)

    @classmethod
    def from_pretrained(cls, checkpoint):
        generator = cls()
        generator.estimator = LlamaForCausalLM.from_pretrained(checkpoint)
        generator.tokenizer = BertTokenizer.from_pretrained(checkpoint)
        return generator

    def generate(self, upper_lines):
        # Append [BEG] so the model continues with the lower line
        prompts = [ upper_line + '[BEG]' for upper_line in upper_lines ]
        batch_inputs = self.tokenizer.batch_encode_plus(prompts,
                                                        add_special_tokens=False,
                                                        padding=True,
                                                        return_tensors='pt',
                                                        return_length=True,
                                                        return_token_type_ids=False)
        # Move tensors to the model device
        batch_inputs = {k: v.to(self.estimator.device) for k, v in batch_inputs.items()}
        upper_lens = batch_inputs['length']
        batch_inputs.pop('length')
        with torch.no_grad():
            lower_lines = self.estimator.generate(**batch_inputs,
                                                  max_length=128,
                                                  eos_token_id=self.tokenizer.convert_tokens_to_ids('[END]'),
                                                  pad_token_id=self.tokenizer.convert_tokens_to_ids('[PAD]'))

        # Decode the generated lower lines, dropping the prompt tokens and special tokens
        for upper_line, lower_line, upper_len in zip(upper_lines, lower_lines, upper_lens):
            lower_line = lower_line[upper_len:]
            lower_line = self.tokenizer.decode(lower_line, skip_special_tokens=True)
            print(f"Upper: {upper_line}\nLower: {''.join(lower_line.split())}\n")

        print('-' * 100)

    def on_train_epoch_start(self):
        self.total_loss = 0
        self.total_size = 0

    def on_train_epoch_end(self):
        # Held-out couplets: generate lower lines for a quick qualitative check after each epoch
        couplets = [('日里千人拱手划船,齐歌狂吼川江号子', '夜间百舸点灯敬佛,漫饮轻吟巴岳清音'),
                    ('入迷途,吞苦果,回头是岸', '到此处,改前非,革面做人'),
                    ('地近秦淮,看碧水蓝天,一行白鹭飞来何处', '门临闹市,入红楼翠馆,四海旅人宾至如归'),
                    ('其巧在古倕以上', '所居介帝君之间'),
                    ('万众齐心,已膺全国文明市', '千帆竞发,再鼓鹭江经济潮')]
        for upper_line, _ in couplets:
            self.generate([upper_line])
        # Save a checkpoint for this epoch
        self.save_pretrained()


def train():
    train_data = pickle.load(open('data/couplet-train.pkl', 'rb'))
    dataloader = DataLoader(train_data,
                            batch_size=128,
                            num_workers=6,
                            persistent_workers=True,
                            shuffle=True,
                            # Keep raw (upper, lower) string pairs; tokenization happens in prepare_inputs
                            collate_fn=lambda batch_data: batch_data)
    estimator = CoupletGenerator()

    trainer = Trainer(max_epochs=10, logger=False, enable_checkpointing=False)
    trainer.fit(estimator, dataloader)


if __name__ == '__main__':
    train()
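
Once training has produced a checkpoint, the from_pretrained helper above can reload it for inference. A minimal usage sketch, meant to live in the same script as the classes above, assuming the final epoch's directory is checkpoints/couplet-9 (adjust the path to whichever epoch you keep):

def predict():
    # Path is an assumption: the directory written by save_pretrained() after the last epoch
    generator = CoupletGenerator.from_pretrained('checkpoints/couplet-9')
    generator.estimator.eval()
    # generate() is called with one upper line at a time, as in on_train_epoch_end
    generator.generate(['万众齐心,已膺全国文明市'])
    generator.generate(['其巧在古倕以上'])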
