我们主要实现了两个预测函数,一个用于给定第一个字,来预测上联和下联,另外一个可以输入整个上联来预测下联,这俩函数其实很类似。
from transformers import GPT2Model from transformers import BertTokenizer from transformers import GPT2Config from datasets import load_from_disk import torch.nn as nn import torch.optim as optim import time import torch
1. 预测上下联
def predict1(inputs): # 任务模型 model = Model() model.load_state_dict(torch.load('model/couplet-gpt2-19.bin')) model.eval() # 分词器 tokenizer = BertTokenizer.from_pretrained('data/tokenizer-encode-tokenizer') # 输入编码 encode_inputs = [tokenizer.convert_tokens_to_ids(inputs)] # 存储结果 model_outputs = encode_inputs # 模型预测 past_key_values = None # 预测长度 max_output_length = 64 for _ in range(max_output_length): # 将编码后的输入转换为张量 model_inputs = torch.tensor(encode_inputs) with torch.no_grad(): # 将输入张量送入模型计算 outputs, past_key_values = model(model_inputs, past_key_values) # 获得最后一个输入字的隐藏状态 y_pred = torch.argmax(outputs[-1], dim=-1).item() model_outputs.append(y_pred) if y_pred == 2: break encode_inputs = [y_pred] # 打印预测结果 print(tokenizer.decode(model_outputs))
2. 只预测下联
def predict2(inputs): # 任务模型 model = Model() model.load_state_dict(torch.load('model/couplet-gpt2-19.bin')) model.eval() # 分词器 tokenizer = BertTokenizer.from_pretrained('data/tokenizer-encode-tokenizer') # 模型输入 inputs = inputs + '[BRK]' encode_inputs = tokenizer.encode(inputs, add_special_tokens=False) # 保存结果 model_outputs = encode_inputs # 模型预测 past_key_values = None # 预测长度 max_output_length = 64 for _ in range(max_output_length): # 将编码后的输入转换为张量 model_inputs = torch.tensor(encode_inputs) # 将输入张量送入模型计算 with torch.no_grad(): outputs, past_key_values = model(model_inputs, past_key_values) # 获得预测结果 y_pred = torch.argmax(outputs[-1], dim=-1).item() # 存储预测结果 model_outputs.append(y_pred) if y_pred == 2: break encode_inputs = [y_pred] # 打印预测结果 print(tokenizer.decode(model_outputs))
3. 调用示例
if __name__ == '__main__': predict1('春') predict1('财') predict1('千') print('-' * 50) predict2('昨夜未开花等我') predict2('门对青山千古看') predict2('地久天长门有庆')
程序预测输出:
春 风 一 路 同 梅 醉 [BRK] 喜 鹊 千 枝 共 雪 欢 [END] 财 源 茂 盛 家 丁 旺 [BRK] 生 意 兴 隆 日 日 新 [END] 千 秋 功 业 辉 南 越 [BRK] 五 岭 风 云 壮 石 门 [END] -------------------------------------------------- 昨 夜 未 开 花 等 我 [BRK] 今 朝 将 落 酒 邀 谁 [END] 门 对 青 山 千 古 看 [BRK] 家 居 旺 地 四 时 新 [END] 地 久 天 长 门 有 庆 [BRK] 年 来 月 满 岁 无 烟 [END]
看着挺工整,懂对联的同学可以品品咋样。