基于 GPT 训练对联模型 – 模型预测

我们主要实现了两个预测函数,一个用于给定第一个字,来预测上联和下联,另外一个可以输入整个上联来预测下联,这俩函数其实很类似。

from transformers import GPT2Model
from transformers import BertTokenizer
from transformers import GPT2Config
from datasets import load_from_disk
import torch.nn as nn
import torch.optim as optim
import time
import torch

1. 预测上下联

def predict1(inputs):

    # 任务模型
    model = Model()
    model.load_state_dict(torch.load('model/couplet-gpt2-19.bin'))
    model.eval()
    # 分词器
    tokenizer = BertTokenizer.from_pretrained('data/tokenizer-encode-tokenizer')
    # 输入编码
    encode_inputs = [tokenizer.convert_tokens_to_ids(inputs)]
    # 存储结果
    model_outputs = encode_inputs

    # 模型预测
    past_key_values = None
    # 预测长度
    max_output_length = 64

    for _ in range(max_output_length):

        # 将编码后的输入转换为张量
        model_inputs = torch.tensor(encode_inputs)
        with torch.no_grad():
            # 将输入张量送入模型计算
            outputs, past_key_values = model(model_inputs, past_key_values)

        # 获得最后一个输入字的隐藏状态
        y_pred = torch.argmax(outputs[-1], dim=-1).item()

        model_outputs.append(y_pred)

        if y_pred == 2:
            break

        encode_inputs = [y_pred]

    # 打印预测结果
    print(tokenizer.decode(model_outputs))

2. 只预测下联

def predict2(inputs):

    # 任务模型
    model = Model()
    model.load_state_dict(torch.load('model/couplet-gpt2-19.bin'))
    model.eval()
    # 分词器
    tokenizer = BertTokenizer.from_pretrained('data/tokenizer-encode-tokenizer')
    # 模型输入
    inputs = inputs + '[BRK]'
    encode_inputs = tokenizer.encode(inputs, add_special_tokens=False)
    # 保存结果
    model_outputs = encode_inputs

    # 模型预测
    past_key_values = None
    # 预测长度
    max_output_length = 64

    for _ in range(max_output_length):

        # 将编码后的输入转换为张量
        model_inputs = torch.tensor(encode_inputs)

        # 将输入张量送入模型计算
        with torch.no_grad():
            outputs, past_key_values = model(model_inputs, past_key_values)

        # 获得预测结果
        y_pred = torch.argmax(outputs[-1], dim=-1).item()
        # 存储预测结果
        model_outputs.append(y_pred)

        if y_pred == 2:
            break
        encode_inputs = [y_pred]

    # 打印预测结果
    print(tokenizer.decode(model_outputs))

3. 调用示例

if __name__ == '__main__':

    predict1('春')
    predict1('财')
    predict1('千')
    print('-' * 50)
    predict2('昨夜未开花等我')
    predict2('门对青山千古看')
    predict2('地久天长门有庆')

程序预测输出:

春 风 一 路 同 梅 醉 [BRK] 喜 鹊 千 枝 共 雪 欢 [END]
财 源 茂 盛 家 丁 旺 [BRK] 生 意 兴 隆 日 日 新 [END]
千 秋 功 业 辉 南 越 [BRK] 五 岭 风 云 壮 石 门 [END]
--------------------------------------------------
昨 夜 未 开 花 等 我 [BRK] 今 朝 将 落 酒 邀 谁 [END]
门 对 青 山 千 古 看 [BRK] 家 居 旺 地 四 时 新 [END]
地 久 天 长 门 有 庆 [BRK] 年 来 月 满 岁 无 烟 [END]

看着挺工整,懂对联的同学可以品品咋样。

未经允许不得转载:一亩三分地 » 基于 GPT 训练对联模型 – 模型预测