实现 Seq2Seq + Attention 生成对联

对联是中国传统文化中的一项独特艺术形式,它不仅要求上下句字数相同,还要对仗工整、意义相对。随着人工智能和自然语言处理技术的进步,如何让机器自动生成符合对联规律的文本,变得越来越有趣也越来越可行。接下来,我们将一起探讨如何搭建一个 Seq2Seq + Attention 架构的模型,自动根据输入的上联生成对仗工整的下联。

网络中的编码器和解码器部分我们选择使用 GRU 模型,GRU 是一种比传统 RNN 更加高效的递归神经网络结构,能够有效解决长序列中的梯度消失问题。结合 Seq2Seq 架构,模型可以将输入的序列转换成输出序列,从而生成流畅连贯的文本。再加上 Attention 机制,让模型在生成文本时,能聚焦于输入中最重要的部分,从而提高生成文本的质量和准确性。

接下来,我们会从数据准备、模型搭建、训练生成,逐步实现每个环节。

1. 数据处理

我从互联网搜索到一个包含 70+ 万对联数据的语料,后面的模型训练也是基于该语料进行训练。下面,我们先对数据做一下处理,主要是去重和构建词表。

import pickle
import re
from tqdm import tqdm


# 对联去重
def demo01():
    is_unicode = lambda text: any('\ue000' <= char <= '\uf8ff' for char in text)

    couplets, hash_set = [], set()
    def append(one, two):
        one, two = one.strip(), two.strip()
        if is_unicode(one) or is_unicode(two):
            return
        one_hash, two_hash = hash(one), hash(two)
        if one_hash not in hash_set and two_hash not in hash_set:
            couplets.append((one, two))
            hash_set.add(one_hash)
            hash_set.add(two_hash)

    fnames = (('couplet/train/in.txt', 'couplet/train/out.txt'), ('couplet/test/in.txt', 'couplet/test/out.txt'))
    for one_fname, two_fname in fnames:
        for one, two in zip(open(one_fname), open(two_fname)):
            append(one, two)

    # 对联数量: 563114
    print('对联数量:', len(couplets))

    pickle.dump(couplets[:-5], open('data/couplet-train.pkl', 'wb'))
    # [('日 里 千 人 拱 手 划 船 , 齐 歌 狂 吼 川 江 号 子', '夜 间 百 舸 点 灯 敬 佛 , 漫 饮 轻 吟 巴 岳 清 音'),
    # ('入 迷 途 , 吞 苦 果 , 回 头 是 岸', '到 此 处 , 改 前 非 , 革 面 做 人'),
    # ('地 近 秦 淮 , 看 碧 水 蓝 天 , 一 行 白 鹭 飞 来 何 处', '门 临 闹 市 , 入 红 楼 翠 馆 , 四 海 旅 人 宾 至 如 归'),
    # ('其 巧 在 古 倕 以 上', '所 居 介 帝 君 之 间'),
    # ('万 众 齐 心 , 已 膺 全 国 文 明 市', '千 帆 竞 发 , 再 鼓 鹭 江 经 济 潮')]
    print(couplets[-5:])



def demo02():
    index_to_word, word_to_index = {}, {}
    couplets = pickle.load(open('data/couplet-train.pkl', 'rb'))

    words = set()
    for one, two in couplets:
        words.update(one.split())
        words.update(two.split())

    word_to_index = {'[PAD]': 0, '[BEG]': 1, '[END]': 2, '[UNK]': 3}
    index_to_word = {0: '[PAD]', 1: '[BEG]', 2: '[END]', 3: '[UNK]'}

    for idx, word in enumerate(words, start=len(word_to_index)):
        if word not in word_to_index:
            word_to_index[word] = idx
            index_to_word[idx] = word

    # 词表大小: 8794
    print('词表大小:', len(word_to_index))

    pickle.dump(word_to_index, open('data/word_to_index.pkl', 'wb'))
    pickle.dump(index_to_word, open('data/index_to_word.pkl', 'wb'))


if __name__ == '__main__':
    demo01()
    demo02()

2. 分词器

import pickle
import torch
from torch.nn.utils.rnn import pad_sequence


class Tokenzier:

    def __init__(self):
        self.word_to_index = pickle.load(open('data/word_to_index.pkl', 'rb'))
        self.index_to_word = pickle.load(open('data/index_to_word.pkl', 'rb'))
        self.unk = self.word_to_index['[UNK]']
        self.pad = self.word_to_index['[PAD]']

    def get_vocab_size(self):
        return len(self.word_to_index)

    def encode(self, texts):
        batch_ids, batch_len = [], []
        for text in texts:
            ids = []
            for word in text.split():
                if word in self.word_to_index:
                    index = self.word_to_index[word]
                else:
                    index = self.unk
                ids.append(index)
            batch_ids.append(torch.tensor(ids))
            batch_len.append(len(ids))

        # 将批次数据 PAD 对齐
        batch_ids = pad_sequence(batch_ids, batch_first=True, padding_value=self.pad)
        batch_len = torch.tensor(batch_len)
        # 降序排列
        indices = torch.argsort(batch_len, descending=True)
        batch_ids = batch_ids[indices]
        batch_len = batch_len[indices]

        return batch_ids, batch_len

    def decode(self, token_ids):
        content = ''
        for token_id in token_ids:
            content += self.index_to_word[token_id]
        return content


if __name__ == '__main__':
    tokenizer = Tokenzier()
    ones = ['晚 风 摇 树 树 还 挺', '愿 景 天 成 无 墨 迹 墨 迹']
    twos = ['[BEG] 晨 露 润 花 花 更 红 [END]', '[BEG] 万 方 乐 奏 有 于 阗 [END]']
    batch_ids, batch_len = tokenizer.encode(ones)
    print(batch_ids)

    batch_ids, batch_len = tokenizer.encode(twos)
    print(batch_ids)

3. 模型搭建

import torch.nn as nn
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence


class CoupletGenerator(nn.Module):

    def __init__(self, vocab_size):
        super(CoupletGenerator, self).__init__()
        embedd_size, hidden_size = 128, 256
        self.vectors = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedd_size)
        self.encoder = nn.GRU(input_size=embedd_size, hidden_size=hidden_size, batch_first=True)
        self.decoder = nn.GRU(input_size=embedd_size, hidden_size=hidden_size, batch_first=True)
        self.attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=1, batch_first=True)
        self.outputs = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def generate(self, upper_line, max_length=20, beg_token_id=1, eos_token_id=2):
        """对联生成函数"""
        device = next(self.parameters()).device
        upper_line = upper_line.to(device)
        upper_embed = self.vectors(upper_line)
        upper_token_vectors, upper_last_hidden = self.encoder(upper_embed)

        start_token = torch.tensor([[beg_token_id]], device=device)
        start_embed = self.vectors(start_token)
        hx = upper_last_hidden
        lower_tokens = []
        for _ in range(max_length):
            # 编码器理解上联
            lower_token_vector, hx = self.decoder(start_embed, hx)
            # 进行注意力计算
            lower_attention_infos, _ = self.attention(query=lower_token_vector,
                                                      key=upper_token_vectors,
                                                      value=upper_token_vectors)
            lower_token_vector + lower_attention_infos
            # 预测当前的标记
            lower_logits = self.outputs(lower_token_vector)
            lower_next_token = torch.argmax(lower_logits, dim=-1)
            lower_tokens.append(lower_next_token.item())
            # 输出 [END],则停止生成
            if lower_next_token.item() == eos_token_id:
                break
            start_embed = self.vectors(lower_next_token)

        return lower_tokens


    def encoder_understand(self, upper_lines, upper_lens):
        """编码器理解上联"""
        upper_embed = self.vectors(upper_lines)
        upper_embed_packed = pack_padded_sequence(upper_embed, lengths=upper_lens, batch_first=True, enforce_sorted=False)
        upper_token_vectors, upper_last_hidden = self.encoder(upper_embed_packed)
        upper_token_vectors, upper_lens = pad_packed_sequence(upper_token_vectors, batch_first=True)
        return upper_token_vectors, upper_last_hidden


    def decoder_generation(self, lower_lines, lower_lens, upper_last_hidden):
        """解码器生成下联"""
        lower_embed = self.vectors(lower_lines)
        lower_embed_packed = pack_padded_sequence(lower_embed, lengths=lower_lens, batch_first=True, enforce_sorted=False)
        lower_token_vectors, lower_last_hidden = self.decoder(lower_embed_packed, upper_last_hidden)
        lower_token_vectors, lower_lens = pad_packed_sequence(lower_token_vectors, batch_first=True)

        return lower_token_vectors


    def forward(self, upper_lines, upper_lens, lower_lines, lower_lens, ignore_index=0):
        """前向计算函数"""
        device = next(self.parameters()).device
        # 输入数据移动到模型所在设备上
        upper_lines, lower_lines = upper_lines.to(device), lower_lines.to(device)
        # 编码器对上联理解
        upper_token_vectors, upper_last_hidden = self.encoder_understand(upper_lines, upper_lens)
        # 解码器生成下联
        decoder_labels = lower_lines[:, 1:]
        decoder_inputs = lower_lines[:, :-1]
        lower_token_vectors = self.decoder_generation(decoder_inputs, lower_lens - 1, upper_last_hidden)
        # 上联是一个批次,创建 mask,避免 PAD 参与注意力计算(True 掩码,表示不参与,False 不掩码,表示参与计算)

        key_padding_mask = (upper_lines == 0)
        lower_attention_vectors, weights = self.attention(query=lower_token_vectors,
                                                          key=upper_token_vectors,
                                                          value=upper_token_vectors,
                                                          key_padding_mask=key_padding_mask)
        # 生成 token 时,结合输入 token 的信息
        lower_token_vectors += lower_attention_vectors
        logits = self.outputs(lower_token_vectors)
        # 计算交叉熵损失,标记为 ignore_index 的 token 不计算损失
        criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
        lower_loss = criterion(logits.view(-1, self.vectors.num_embeddings), decoder_labels.reshape(-1))

        return logits, lower_loss


if __name__ == '__main__':
    from tokenizer import Tokenzier
    tokenizer = Tokenzier()
    upper_lines = ['晚 风 摇 树 树 还 挺', '愿 景 天 成 无 墨 迹 墨 迹']
    lower_lines = ['[BEG] 晨 露 润 花 花 更 红 [END]', '[BEG] 万 方 乐 奏 有 于 阗 [END]']
    upper_lines, upper_lens = tokenizer.encode(upper_lines)
    lower_lines, lower_lens = tokenizer.encode(lower_lines)

    estimator = CoupletGenerator(vocab_size=tokenizer.get_vocab_size())
    logits, loss = estimator(upper_lines, upper_lens, lower_lines, lower_lens)

    upper_lines, upper_lens = tokenizer.encode(['晚 风 摇 树 树 还 挺', ])
    text_tokens = estimator.generate(upper_lines)
    print(text_tokens)

    text = tokenizer.decode(text_tokens)
    print(text)

4. 模型训练

from tokenizer import Tokenzier
from torch.utils.data import DataLoader
import pickle
from tqdm import tqdm
from couplet_generator import CoupletGenerator
from tokenizer import Tokenzier
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR



def train(epochs=10, device='cuda' if torch.cuda.is_available() else 'cpu'):
    tokenizer = Tokenzier()
    estimator = CoupletGenerator(vocab_size=tokenizer.get_vocab_size()).to(device)
    optimizer = optim.Adam(estimator.parameters(), lr=1e-3)
    scheduler = StepLR(optimizer, step_size=5000, gamma=0.9)

    train_data = pickle.load(open('data/couplet-train.pkl', 'rb'))
    def collate_fn(batch_data):
        upper_lines, lower_lines = [], []
        for upper_line, lower_line in batch_data:
            upper_lines.append(upper_line)
            lower_lines.append('[BEG] ' + lower_line + ' [END]')
        upper_lines, upper_lens = tokenizer.encode(upper_lines)
        lower_lines, lower_lens = tokenizer.encode(lower_lines)
        return upper_lines, upper_lens, lower_lines, lower_lens

    dataloader = DataLoader(train_data, batch_size=64, shuffle=True, collate_fn=collate_fn)

    record_loss = []
    for epoch in range(epochs):
        estimator.train()
        total_loss, total_size = 0, 0
        progress = tqdm(range(len(dataloader)), ncols=100)
        for upper_lines, upper_lens, lower_lines, lower_lens in dataloader:
            # 移动输入数据到设备上
            upper_lines, lower_lines = upper_lines.to(device), lower_lines.to(device)
            # 梯度清零
            optimizer.zero_grad()
            # 损失计算
            logits, loss = estimator(upper_lines, upper_lens, lower_lines, lower_lens)
            # 梯度计算
            loss.backward()
            # 参数更新
            optimizer.step()
            # 学习率更新
            scheduler.step()
            # 损失统计
            total_loss += loss.item() * torch.sum(lower_lens - 1)
            total_size += torch.sum(lower_lens - 1)
            progress.set_description(f'Epoch {epoch + 1}/{epochs}, Loss: {total_loss / total_size:.4f} Lr: {scheduler.get_last_lr()[0]:.6f}')
            progress.update()
        progress.close()
        record_loss.append(total_loss / total_size)
        # 存储模型
        pickle.dump(estimator, open(f'model/couplet-estimator-{epoch + 1}.pkl', 'wb'))
        pickle.dump(tokenizer, open(f'model/couplet-tokenizer-{epoch + 1}.pkl', 'wb'))

    plt.plot(range(len(record_loss)), record_loss)
    plt.title('Loss Curve')
    plt.grid()
    plt.show()


if __name__ == '__main__':
    train()
Epoch 1/10, Loss: 4.3985 Lr: 0.001000: 100%|████████████████████| 8799/8799 [01:43<00:00, 85.33it/s]
Epoch 2/10, Loss: 3.8814 Lr: 0.000900: 100%|████████████████████| 8799/8799 [01:42<00:00, 85.51it/s]
Epoch 3/10, Loss: 3.7426 Lr: 0.000810: 100%|████████████████████| 8799/8799 [01:42<00:00, 85.43it/s]
Epoch 4/10, Loss: 3.6580 Lr: 0.000729: 100%|████████████████████| 8799/8799 [01:43<00:00, 85.33it/s]
Epoch 5/10, Loss: 3.5968 Lr: 0.000656: 100%|████████████████████| 8799/8799 [01:43<00:00, 85.29it/s]
Epoch 6/10, Loss: 3.5489 Lr: 0.000590: 100%|████████████████████| 8799/8799 [01:43<00:00, 85.42it/s]
Epoch 7/10, Loss: 3.5099 Lr: 0.000531: 100%|████████████████████| 8799/8799 [01:43<00:00, 85.35it/s]
Epoch 8/10, Loss: 3.4764 Lr: 0.000478: 100%|████████████████████| 8799/8799 [01:43<00:00, 85.31it/s]
Epoch 9/10, Loss: 3.4479 Lr: 0.000430: 100%|████████████████████| 8799/8799 [01:43<00:00, 85.38it/s]
Epoch 10/10, Loss: 3.4228 Lr: 0.000387: 100%|███████████████████| 8799/8799 [01:43<00:00, 85.37it/s]

5. 对联生成

import pickle
import torch


def inference():

    estimator = pickle.load(open(f'model/couplet-estimator-7.pkl', 'rb'))
    tokenizer = pickle.load(open(f'model/couplet-tokenizer-7.pkl', 'rb'))

    upper_lines = [('日 里 千 人 拱 手 划 船 , 齐 歌 狂 吼 川 江 号 子', '夜 间 百 舸 点 灯 敬 佛 , 漫 饮 轻 吟 巴 岳 清 音'),
                   ('入 迷 途 , 吞 苦 果 , 回 头 是 岸', '到 此 处 , 改 前 非 , 革 面 做 人'),
                   ('地 近 秦 淮 , 看 碧 水 蓝 天 , 一 行 白 鹭 飞 来 何 处', '门 临 闹 市 , 入 红 楼 翠 馆 , 四 海 旅 人 宾 至 如 归'),
                   ('其 巧 在 古 倕 以 上', '所 居 介 帝 君 之 间'),
                   ('万 众 齐 心 , 已 膺 全 国 文 明 市', '千 帆 竞 发 , 再 鼓 鹭 江 经 济 潮')]

    for upper_line, lower_line in upper_lines:
        upper_ids, _ = tokenizer.encode([upper_line])
        with torch.no_grad():
            generate_lower_line = estimator.generate(upper_ids)
        generate_lower_line = tokenizer.decode(generate_lower_line)
        print('输入上联:', ''.join(upper_line.split()))
        print('参考下联:', ''.join(lower_line.split()))
        print('生成下联:', ''.join(generate_lower_line.split()))
        print('-' * 50)


if __name__ == '__main__':
    inference()
输入上联: 日里千人拱手划船,齐歌狂吼川江号子
参考下联: 夜间百舸点灯敬佛,漫饮轻吟巴岳清音
生成下联: 天中万物归心向日,共建和谐社会文明[END]
--------------------------------------------------
输入上联: 入迷途,吞苦果,回头是岸
参考下联: 到此处,改前非,革面做人
生成下联: 出水水,流光花,出手为天[END]
--------------------------------------------------
输入上联: 地近秦淮,看碧水蓝天,一行白鹭飞来何处
参考下联: 门临闹市,入红楼翠馆,四海旅人宾至如归
生成下联: 天高云汉,听黄河碧浪,千里青山隐去无边[END]
--------------------------------------------------
输入上联: 其巧在古倕以上
参考下联: 所居介帝君之间
生成下联: 斯文于武城而后[END]
--------------------------------------------------
输入上联: 万众齐心,已膺全国文明市
参考下联: 千帆竞发,再鼓鹭江经济潮
生成下联: 千秋大业,再绘宏图锦绣图[END]
--------------------------------------------------

未经允许不得转载:一亩三分地 » 实现 Seq2Seq + Attention 生成对联