基于 GRU + Attention 生成对联 – 上联模型

上联生成模型比较简单,使用词嵌入层 + GRU + 线性层即可,其训练数据的构造如下:

输入: 雪 映 梅 花 梅 映 雪
目标: 映 梅 花 梅 映 雪 [EOS]

输入数据与目标数据错开一个位置,即:在每个位置上,模型根据当前词以及之前所有的词(由 GRU 的隐藏状态记录)预测下一个词。每次迭代向网络送入一条数据(batch_size=1),并计算这条数据上的平均损失。

import random
import torch
import torch.nn as nn
import pickle
from torch.utils.data import DataLoader
import torch.optim as optim
import time
from tqdm import tqdm
import heapq
import numpy as np
import matplotlib.pyplot as plt
import math


# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

1. 词表类

class SingleVocab:
    """Vocabulary for the upper-couplet model.

    Loads a pickled dict with the keys ``word_to_index``, ``index_to_word`` and
    ``vocab_size`` and exposes them as attributes, plus ``EOS``: the index of
    the ``[EOS]`` token.
    """

    def __init__(self, vocab_single_path='data/vocab-single.pkl'):
        """Load the vocabulary from *vocab_single_path*.

        Args:
            vocab_single_path: path of the pickled vocabulary file.

        Raises:
            FileNotFoundError: if the file does not exist.
            KeyError: if an expected key is missing or the vocabulary has no
                ``[EOS]`` entry.
        """
        # NOTE: unpickling can execute arbitrary code; only load vocabulary
        # files you produced yourself. The ``with`` block closes the handle,
        # which the previous ``pickle.load(open(...))`` leaked.
        with open(vocab_single_path, 'rb') as f:
            vocab = pickle.load(f)
        self.word_to_index = vocab['word_to_index']
        self.index_to_word = vocab['index_to_word']
        self.vocab_size = vocab['vocab_size']
        self.EOS = self.word_to_index['[EOS]']

    def decode(self, input_ids):
        """Map a sequence of token ids back to text.

        Args:
            input_ids: iterable of ids usable as keys/indices of
                ``index_to_word``.

        Returns:
            The tokens concatenated with no separator; special tokens such as
            ``[EOS]`` are kept verbatim.
        """
        return ''.join(self.index_to_word[input_id] for input_id in input_ids)


# Module-level vocabulary; SingleModel reads vocab_data.vocab_size to size its
# embedding and output layers. It is loaded at import time, so
# data/vocab-single.pkl must exist before this module can be imported.
vocab_data = SingleVocab()

2. 数据类

class SingleDataset:
    """Training corpus for the upper-couplet model.

    Each stored sample is a sequence of token ids; ``__getitem__`` turns it
    into the (input, target) pair for next-token prediction, where the target
    is the input shifted left by one position.
    """

    def __init__(self, train_single_path='data/train-single.pkl'):
        """Load the pickled training data from *train_single_path*.

        The file must hold a dict with the keys ``train_data`` (list of
        samples) and ``train_size`` (number of samples).
        """
        # NOTE: unpickling can execute arbitrary code; only load training
        # files you produced yourself. The ``with`` block closes the handle,
        # which the previous ``pickle.load(open(...))`` leaked.
        with open(train_single_path, 'rb') as f:
            data = pickle.load(f)
        self.train_data = data['train_data']
        self.train_size = data['train_size']

    def __len__(self):
        return self.train_size

    def __getitem__(self, index):
        """Return ``(x, y)`` int64 tensors on the global ``device``.

        ``x`` is the sample without its last token and ``y`` the sample
        without its first token. Out-of-range indices (including negative
        ones) are clamped into ``[0, train_size - 1]`` instead of raising.
        """
        index = min(max(0, index), self.train_size - 1)
        sample = self.train_data[index]
        x_train = torch.tensor(sample[:-1], dtype=torch.int64, device=device)
        y_train = torch.tensor(sample[1:], dtype=torch.int64, device=device)
        return x_train, y_train

3. 模型类

class SingleModel(nn.Module):
    """Embedding -> single-layer GRU -> linear next-token model.

    ``forward`` takes token ids of shape (batch, seq_len) (the GRU is
    ``batch_first``) plus a hidden state. It returns unnormalised next-token
    scores of shape (batch, seq_len, vocab_size) and the new hidden state.
    """

    def __init__(self, vocab_size=None, embedding_dim=128, hidden_size=256):
        """Build the layers.

        Args:
            vocab_size: number of tokens; ``None`` (the default) uses the
                module-level ``vocab_data.vocab_size`` as before.
            embedding_dim: size of each token embedding.
            hidden_size: size of the GRU hidden state.
        """
        super(SingleModel, self).__init__()
        if vocab_size is None:
            vocab_size = vocab_data.vocab_size
        # Kept so h0() can build a matching state without a hard-coded 256.
        self.hidden_size = hidden_size
        self.ebd = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.gru = nn.GRU(input_size=embedding_dim, hidden_size=hidden_size, batch_first=True)
        self.out = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, inputs, hn):
        """Run the model over *inputs* starting from hidden state *hn*.

        Returns:
            ``(scores, hn)``: per-position vocabulary scores and the final
            GRU hidden state.
        """
        embedded = self.ebd(inputs)
        output, hn = self.gru(embedded, hn)
        return self.out(output), hn

    def h0(self, batch_size=1):
        """Return a zero initial hidden state of shape (1, batch_size, hidden_size).

        The tensor is created on the device of the model's own parameters, so
        it always matches the model even if the model was not moved to the
        global ``device``.
        """
        return torch.zeros(1, batch_size, self.hidden_size, device=self.out.weight.device)

4. 训练函数

def _run_epoch(model, dataloader, optimizer, criterion, desc):
    """Train *model* for one pass over *dataloader*; return the mean per-token loss."""
    total_loss = 0.0
    total_tokens = 0
    # Iterating tqdm over the dataloader itself keeps the bar in step with the
    # data; zipping it with a separate range left the bar at N-1/N.
    for x, y in tqdm(dataloader, desc=desc):
        y_pred, _ = model(x, model.h0())
        # batch_size is 1: drop only the batch dimension. A bare squeeze() would
        # also drop the time dimension of a length-1 sequence.
        y_pred = y_pred.squeeze(0)
        y = y.squeeze(0)
        loss = criterion(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss averages over tokens; weight by token count so the
        # epoch figure is a per-token mean across sequences of different lengths.
        num_tokens = y.numel()
        total_loss += loss.item() * num_tokens
        total_tokens += num_tokens
    return total_loss / total_tokens


def train():
    """Train the upper-couplet model for 200 epochs and save the results.

    Writes ``model/single-model.pth`` (model/optimizer state, learning rate,
    final loss) and ``model/single-model-loss.pth`` (per-epoch loss list),
    creating the ``model`` directory if needed.
    """
    import os

    train_data = SingleDataset()
    # batch_size=1: the default collate never has to stack sequences of
    # different lengths, so no padding is needed.
    dataloader = DataLoader(train_data, batch_size=1, shuffle=True)
    # .to() also works for the CPU fallback; .cuda(device) fails without a GPU.
    model = SingleModel().to(device)
    learning_rate = 2e-4
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()
    epochs = 200
    # Named so it does not shadow the module-level loss_curve() function.
    loss_history = []

    for epoch_idx in range(1, epochs + 1):
        start_time = time.time()
        loss = _run_epoch(model, dataloader, optimizer, criterion,
                          desc='epoch %d/%d' % (epoch_idx, epochs))
        end_time = time.time()
        loss_history.append(loss)
        print('loss: %4f perplexity: %4f time: %.2f' % (loss, math.exp(loss), end_time - start_time))

    model_save_info = {
        'epoch': epochs,
        'optimizer': optimizer.state_dict(),
        'lr': learning_rate,
        'loss': loss,
        'model': model.state_dict()
    }

    model_loss_info = {
        'epoch': epochs,
        'loss': loss_history
    }

    os.makedirs('model', exist_ok=True)
    torch.save(model_save_info, 'model/single-model.pth')
    torch.save(model_loss_info, 'model/single-model-loss.pth')


def loss_curve(path='model/single-model-loss.pth'):
    """Plot the per-epoch training loss saved by ``train()``.

    Args:
        path: loss-history file written by ``train()`` (a dict whose
            ``'loss'`` entry is the list of per-epoch losses).
    """
    model_loss = torch.load(path)
    losses = model_loss['loss']
    # Epochs are numbered from 1 in the training log; range(epoch) started at 0.
    plt.plot(range(1, len(losses) + 1), losses)
    plt.title('Single Loss Curve')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.grid()
    plt.show()

5. 预测函数

def predict(start_word, max_len=50):
    """Greedily generate an upper couplet beginning with *start_word* and print it.

    Args:
        start_word: first character; must be in the vocabulary.
        max_len: maximum number of tokens generated after *start_word*, so a
            model that never emits ``[EOS]`` cannot loop forever.

    Raises:
        KeyError: if *start_word* is not in the vocabulary.
    """
    vocab = SingleVocab()
    # map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
    model_data = torch.load('model/single-model.pth', map_location=device)
    # .to() also works for the CPU fallback; .cuda(device) fails without a GPU.
    model = SingleModel().to(device)
    model.load_state_dict(model_data['model'])
    model.eval()

    result = [vocab.word_to_index[start_word]]
    with torch.no_grad():
        hn = model.h0()
        for _ in range(max_len):
            # Feed only the newest token; earlier context lives in hn.
            inputs = torch.tensor([[result[-1]]], dtype=torch.int64, device=device)
            y_pred, hn = model(inputs, hn)
            next_id = torch.argmax(y_pred.squeeze()).item()
            result.append(next_id)
            if next_id == vocab.EOS:
                break

    print(vocab.decode(result))

6. 程序入口

if __name__ == '__main__':
    # Train, show the loss curve, then sample one couplet per starting character.
    train()
    loss_curve()
    for first_char in ('晚', '愿', '不', '常', '张'):
        predict(first_char)

7. 训练结果

经过 200 个 epoch 的训练(每个 epoch 大约需要 3~5 秒),训练损失的变化曲线可以用 loss_curve() 绘制(原文此处的配图已缺失)。

以“晚”、“愿”、“不”、“常”、“张”为首字生成的上联如下:

晚风摇树树还挺[EOS]
愿景天成无墨迹[EOS]
不与时人争席地[EOS]
常赋江山大美[EOS]
张帆鼓浪行千里[EOS]

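总体来看效果还不错,不过“常赋江山大美”只有 6 个字,与其余 7 字上联的字数不一致,说明模型对句子长度的把握还不稳定。具体的训练过程输出如下。日志中进度条显示 1999/2000,只是显示误差:生成这份日志的代码用 zip 同时遍历 dataloader 和进度条,最后一次进度更新没有执行,每个 epoch 实际都训练了全部 2000 条数据。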

epoch 1/200: 100%|████████████████████████▉| 1999/2000 [00:03<00:00, 636.83it/s]
loss: 6.432335 perplexity: 621.623927 time: 3.19
epoch 2/200: 100%|████████████████████████▉| 1999/2000 [00:03<00:00, 632.13it/s]
loss: 5.780665 perplexity: 323.974468 time: 3.21
epoch 3/200: 100%|████████████████████████▉| 1999/2000 [00:03<00:00, 637.52it/s]
loss: 5.380104 perplexity: 217.044949 time: 3.19
epoch 4/200: 100%|████████████████████████▉| 1999/2000 [00:03<00:00, 623.76it/s]
loss: 4.955632 perplexity: 141.972268 time: 3.26
epoch 5/200: 100%|████████████████████████▉| 1999/2000 [00:03<00:00, 629.55it/s]
loss: 4.520249 perplexity: 91.858422 time: 3.23
epoch 6/200: 100%|████████████████████████▉| 1999/2000 [00:03<00:00, 629.16it/s]
loss: 4.094169 perplexity: 59.989455 time: 3.23
epoch 7/200: 100%|████████████████████████▉| 1999/2000 [00:03<00:00, 630.82it/s]
loss: 3.691508 perplexity: 40.105277 time: 3.22
epoch 8/200: 100%|████████████████████████▉| 1999/2000 [00:03<00:00, 634.26it/s]
loss: 3.303389 perplexity: 27.204667 time: 3.20
epoch 9/200: 100%|████████████████████████▉| 1999/2000 [00:03<00:00, 625.10it/s]
loss: 2.939015 perplexity: 18.897227 time: 3.25
epoch 10/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 631.02it/s]
loss: 2.602245 perplexity: 13.493993 time: 3.22
epoch 11/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 623.56it/s]
loss: 2.288203 perplexity: 9.857207 time: 3.26
epoch 12/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 631.67it/s]
loss: 1.997882 perplexity: 7.373421 time: 3.22
epoch 13/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 624.93it/s]
loss: 1.729811 perplexity: 5.639588 time: 3.25
epoch 14/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 630.77it/s]
loss: 1.490478 perplexity: 4.439216 time: 3.22
epoch 15/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 634.66it/s]
loss: 1.278635 perplexity: 3.591735 time: 3.20
epoch 16/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 624.20it/s]
loss: 1.093220 perplexity: 2.983865 time: 3.25
epoch 17/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 633.94it/s]
loss: 0.933750 perplexity: 2.544033 time: 3.20
epoch 18/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 632.48it/s]
loss: 0.798944 perplexity: 2.223193 time: 3.21
epoch 19/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 631.02it/s]
loss: 0.692281 perplexity: 1.998269 time: 3.22
epoch 20/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 634.73it/s]
loss: 0.606529 perplexity: 1.834054 time: 3.20
epoch 21/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 621.57it/s]
loss: 0.542258 perplexity: 1.719886 time: 3.27
epoch 22/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 633.48it/s]
loss: 0.493715 perplexity: 1.638392 time: 3.21
epoch 23/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 634.76it/s]
loss: 0.456420 perplexity: 1.578413 time: 3.20
epoch 24/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 631.26it/s]
loss: 0.430961 perplexity: 1.538736 time: 3.22
epoch 25/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 625.60it/s]
loss: 0.410036 perplexity: 1.506872 time: 3.25
epoch 26/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 629.56it/s]
loss: 0.393928 perplexity: 1.482794 time: 3.23
epoch 27/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 636.91it/s]
loss: 0.383314 perplexity: 1.467139 time: 3.19
epoch 28/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 632.67it/s]
loss: 0.373420 perplexity: 1.452694 time: 3.21
epoch 29/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 638.82it/s]
loss: 0.366879 perplexity: 1.443224 time: 3.18
epoch 30/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 633.69it/s]
loss: 0.359319 perplexity: 1.432354 time: 3.21
epoch 31/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 635.11it/s]
loss: 0.354548 perplexity: 1.425536 time: 3.20
epoch 32/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 628.70it/s]
loss: 0.350743 perplexity: 1.420122 time: 3.23
epoch 33/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 639.61it/s]
loss: 0.347951 perplexity: 1.416162 time: 3.18
epoch 34/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 640.97it/s]
loss: 0.344394 perplexity: 1.411134 time: 3.17
epoch 35/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 642.90it/s]
loss: 0.342793 perplexity: 1.408877 time: 3.16
epoch 36/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 639.81it/s]
loss: 0.339763 perplexity: 1.404614 time: 3.17
epoch 37/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 642.80it/s]
loss: 0.336566 perplexity: 1.400131 time: 3.16
epoch 38/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 640.65it/s]
loss: 0.336262 perplexity: 1.399706 time: 3.17
epoch 39/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 639.85it/s]
loss: 0.334987 perplexity: 1.397922 time: 3.17
epoch 40/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 641.60it/s]
loss: 0.333208 perplexity: 1.395437 time: 3.17
epoch 41/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 637.80it/s]
loss: 0.331910 perplexity: 1.393627 time: 3.18
epoch 42/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 640.21it/s]
loss: 0.330891 perplexity: 1.392208 time: 3.17
epoch 43/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 645.01it/s]
loss: 0.329723 perplexity: 1.390582 time: 3.15
epoch 44/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 631.31it/s]
loss: 0.329311 perplexity: 1.390009 time: 3.22
epoch 45/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 640.19it/s]
loss: 0.327033 perplexity: 1.386847 time: 3.17
epoch 46/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 636.74it/s]
loss: 0.326686 perplexity: 1.386366 time: 3.19
epoch 47/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 644.35it/s]
loss: 0.326353 perplexity: 1.385905 time: 3.15
epoch 48/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 639.27it/s]
loss: 0.324962 perplexity: 1.383978 time: 3.18
epoch 49/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 630.34it/s]
loss: 0.324677 perplexity: 1.383583 time: 3.22
epoch 50/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 641.42it/s]
loss: 0.325075 perplexity: 1.384135 time: 3.17
epoch 51/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 639.68it/s]
loss: 0.322915 perplexity: 1.381148 time: 3.18
epoch 52/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 643.88it/s]
loss: 0.323350 perplexity: 1.381748 time: 3.16
epoch 53/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 638.81it/s]
loss: 0.324574 perplexity: 1.383442 time: 3.18
epoch 54/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 636.42it/s]
loss: 0.323594 perplexity: 1.382087 time: 3.19
epoch 55/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 641.42it/s]
loss: 0.323098 perplexity: 1.381401 time: 3.17
epoch 56/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 633.38it/s]
loss: 0.320040 perplexity: 1.377183 time: 3.21
epoch 57/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 639.94it/s]
loss: 0.323073 perplexity: 1.381366 time: 3.17
epoch 58/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 643.20it/s]
loss: 0.321338 perplexity: 1.378971 time: 3.16
epoch 59/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 630.18it/s]
loss: 0.321702 perplexity: 1.379474 time: 3.22
epoch 60/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 638.74it/s]
loss: 0.321294 perplexity: 1.378911 time: 3.18
epoch 61/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 633.74it/s]
loss: 0.321074 perplexity: 1.378608 time: 3.21
epoch 62/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 635.37it/s]
loss: 0.322961 perplexity: 1.381211 time: 3.20
epoch 63/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 640.24it/s]
loss: 0.318891 perplexity: 1.375602 time: 3.17
epoch 64/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 640.22it/s]
loss: 0.320656 perplexity: 1.378031 time: 3.17
epoch 65/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 641.05it/s]
loss: 0.319825 perplexity: 1.376887 time: 3.17
epoch 66/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 630.26it/s]
loss: 0.320982 perplexity: 1.378481 time: 3.22
epoch 67/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 644.15it/s]
loss: 0.319009 perplexity: 1.375764 time: 3.15
epoch 68/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 641.31it/s]
loss: 0.319519 perplexity: 1.376465 time: 3.17
epoch 69/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 634.25it/s]
loss: 0.318731 perplexity: 1.375381 time: 3.20
epoch 70/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 641.27it/s]
loss: 0.319647 perplexity: 1.376642 time: 3.17
epoch 71/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 636.31it/s]
loss: 0.319307 perplexity: 1.376174 time: 3.19
epoch 72/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 638.65it/s]
loss: 0.319043 perplexity: 1.375811 time: 3.18
epoch 73/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 636.85it/s]
loss: 0.318372 perplexity: 1.374888 time: 3.19
epoch 74/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 638.42it/s]
loss: 0.318556 perplexity: 1.375141 time: 3.18
epoch 75/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 630.80it/s]
loss: 0.318014 perplexity: 1.374395 time: 3.22
epoch 76/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 637.60it/s]
loss: 0.317174 perplexity: 1.373241 time: 3.19
epoch 77/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 631.15it/s]
loss: 0.317120 perplexity: 1.373167 time: 3.22
epoch 78/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 641.35it/s]
loss: 0.317333 perplexity: 1.373460 time: 3.17
epoch 79/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 642.71it/s]
loss: 0.316914 perplexity: 1.372884 time: 3.16
epoch 80/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 638.43it/s]
loss: 0.316669 perplexity: 1.372548 time: 3.18
epoch 81/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 643.14it/s]
loss: 0.316532 perplexity: 1.372361 time: 3.16
epoch 82/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 640.99it/s]
loss: 0.316158 perplexity: 1.371847 time: 3.17
epoch 83/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 640.99it/s]
loss: 0.316418 perplexity: 1.372204 time: 3.17
epoch 84/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 634.36it/s]
loss: 0.316002 perplexity: 1.371634 time: 3.20
epoch 85/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 642.88it/s]
loss: 0.315444 perplexity: 1.370869 time: 3.16
epoch 86/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 644.89it/s]
loss: 0.319548 perplexity: 1.376505 time: 3.15
epoch 87/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 643.06it/s]
loss: 0.312315 perplexity: 1.366586 time: 3.16
epoch 88/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 632.38it/s]
loss: 0.314009 perplexity: 1.368901 time: 3.21
epoch 89/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 643.23it/s]
loss: 0.316672 perplexity: 1.372553 time: 3.16
epoch 90/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 638.10it/s]
loss: 0.315381 perplexity: 1.370782 time: 3.18
epoch 91/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 642.47it/s]
loss: 0.313921 perplexity: 1.368781 time: 3.16
epoch 92/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 641.13it/s]
loss: 0.315594 perplexity: 1.371074 time: 3.17
epoch 93/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 634.60it/s]
loss: 0.315773 perplexity: 1.371318 time: 3.20
epoch 94/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 638.64it/s]
loss: 0.316170 perplexity: 1.371864 time: 3.18
epoch 95/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 637.83it/s]
loss: 0.314085 perplexity: 1.369006 time: 3.18
epoch 96/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 640.03it/s]
loss: 0.314532 perplexity: 1.369618 time: 3.17
epoch 97/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 641.53it/s]
loss: 0.314977 perplexity: 1.370228 time: 3.17
epoch 98/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 638.31it/s]
loss: 0.315985 perplexity: 1.371609 time: 3.18
epoch 99/200: 100%|███████████████████████▉| 1999/2000 [00:03<00:00, 640.30it/s]
loss: 0.315175 perplexity: 1.370499 time: 3.17
epoch 100/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 630.23it/s]
loss: 0.314594 perplexity: 1.369703 time: 3.22
epoch 101/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 638.32it/s]
loss: 0.310872 perplexity: 1.364614 time: 3.18
epoch 102/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 640.98it/s]
loss: 0.315970 perplexity: 1.371589 time: 3.17
epoch 103/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 632.05it/s]
loss: 0.313312 perplexity: 1.367948 time: 3.21
epoch 104/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 635.79it/s]
loss: 0.315298 perplexity: 1.370667 time: 3.19
epoch 105/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 637.96it/s]
loss: 0.315189 perplexity: 1.370518 time: 3.18
epoch 106/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 642.77it/s]
loss: 0.316280 perplexity: 1.372015 time: 3.16
epoch 107/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 641.39it/s]
loss: 0.314321 perplexity: 1.369329 time: 3.17
epoch 108/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 639.93it/s]
loss: 0.314287 perplexity: 1.369283 time: 3.17
epoch 109/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 643.21it/s]
loss: 0.314460 perplexity: 1.369520 time: 3.16
epoch 110/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 645.78it/s]
loss: 0.314167 perplexity: 1.369118 time: 3.15
epoch 111/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 643.01it/s]
loss: 0.314302 perplexity: 1.369303 time: 3.16
epoch 112/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 641.90it/s]
loss: 0.313835 perplexity: 1.368664 time: 3.16
epoch 113/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 635.25it/s]
loss: 0.314017 perplexity: 1.368913 time: 3.20
epoch 114/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 643.30it/s]
loss: 0.313708 perplexity: 1.368490 time: 3.16
epoch 115/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 640.41it/s]
loss: 0.313603 perplexity: 1.368346 time: 3.17
epoch 116/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 641.54it/s]
loss: 0.313121 perplexity: 1.367687 time: 3.17
epoch 117/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 634.46it/s]
loss: 0.312999 perplexity: 1.367521 time: 3.20
epoch 118/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 642.05it/s]
loss: 0.313141 perplexity: 1.367714 time: 3.16
epoch 119/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 642.73it/s]
loss: 0.313562 perplexity: 1.368291 time: 3.16
epoch 120/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 639.97it/s]
loss: 0.313531 perplexity: 1.368247 time: 3.17
epoch 121/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 644.84it/s]
loss: 0.313099 perplexity: 1.367657 time: 3.15
epoch 122/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 641.89it/s]
loss: 0.312218 perplexity: 1.366452 time: 3.16
epoch 123/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 643.17it/s]
loss: 0.313157 perplexity: 1.367736 time: 3.16
epoch 124/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 643.18it/s]
loss: 0.311352 perplexity: 1.365269 time: 3.16
epoch 125/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 639.01it/s]
loss: 0.311111 perplexity: 1.364940 time: 3.18
epoch 126/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 644.94it/s]
loss: 0.309859 perplexity: 1.363233 time: 3.15
epoch 127/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 635.29it/s]
loss: 0.311036 perplexity: 1.364838 time: 3.20
epoch 128/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 641.33it/s]
loss: 0.310265 perplexity: 1.363786 time: 3.17
epoch 129/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 635.90it/s]
loss: 0.311719 perplexity: 1.365771 time: 3.19
epoch 130/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 639.11it/s]
loss: 0.312033 perplexity: 1.366200 time: 3.18
epoch 131/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 643.89it/s]
loss: 0.312288 perplexity: 1.366549 time: 3.16
epoch 132/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 632.76it/s]
loss: 0.310948 perplexity: 1.364718 time: 3.21
epoch 133/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 641.51it/s]
loss: 0.311087 perplexity: 1.364907 time: 3.17
epoch 134/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 642.98it/s]
loss: 0.312476 perplexity: 1.366805 time: 3.16
epoch 135/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 643.90it/s]
loss: 0.313293 perplexity: 1.367922 time: 3.16
epoch 136/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 640.55it/s]
loss: 0.313457 perplexity: 1.368147 time: 3.17
epoch 137/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 641.65it/s]
loss: 0.312856 perplexity: 1.367325 time: 3.17
epoch 138/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 639.91it/s]
loss: 0.312833 perplexity: 1.367293 time: 3.17
epoch 139/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 643.16it/s]
loss: 0.312676 perplexity: 1.367078 time: 3.16
epoch 140/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 641.62it/s]
loss: 0.311934 perplexity: 1.366065 time: 3.17
epoch 141/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 635.12it/s]
loss: 0.312199 perplexity: 1.366426 time: 3.20
epoch 142/200: 100%|██████████████████████▉| 1999/2000 [00:03<00:00, 578.61it/s]
loss: 0.311822 perplexity: 1.365911 time: 3.51
epoch 143/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 418.35it/s]
loss: 0.312174 perplexity: 1.366393 time: 4.83
epoch 144/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 422.59it/s]
loss: 0.309614 perplexity: 1.362899 time: 4.78
epoch 145/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 417.97it/s]
loss: 0.310990 perplexity: 1.364775 time: 4.83
epoch 146/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 423.15it/s]
loss: 0.309620 perplexity: 1.362907 time: 4.77
epoch 147/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 417.29it/s]
loss: 0.310813 perplexity: 1.364534 time: 4.84
epoch 148/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 422.26it/s]
loss: 0.310956 perplexity: 1.364729 time: 4.78
epoch 149/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.30it/s]
loss: 0.311995 perplexity: 1.366147 time: 4.85
epoch 150/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 420.83it/s]
loss: 0.311527 perplexity: 1.365509 time: 4.80
epoch 151/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.15it/s]
loss: 0.312348 perplexity: 1.366630 time: 4.85
epoch 152/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 421.26it/s]
loss: 0.311857 perplexity: 1.365959 time: 4.80
epoch 153/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.81it/s]
loss: 0.311749 perplexity: 1.365812 time: 4.85
epoch 154/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 421.42it/s]
loss: 0.311018 perplexity: 1.364814 time: 4.79
epoch 155/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.36it/s]
loss: 0.312835 perplexity: 1.367296 time: 4.85
epoch 156/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 421.39it/s]
loss: 0.309693 perplexity: 1.363007 time: 4.79
epoch 157/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.13it/s]
loss: 0.310910 perplexity: 1.364667 time: 4.85
epoch 158/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 420.95it/s]
loss: 0.311337 perplexity: 1.365249 time: 4.80
epoch 159/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.40it/s]
loss: 0.311718 perplexity: 1.365770 time: 4.85
epoch 160/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 420.80it/s]
loss: 0.310479 perplexity: 1.364079 time: 4.80
epoch 161/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.38it/s]
loss: 0.310993 perplexity: 1.364779 time: 4.85
epoch 162/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 421.10it/s]
loss: 0.310943 perplexity: 1.364711 time: 4.80
epoch 163/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 417.27it/s]
loss: 0.310519 perplexity: 1.364133 time: 4.84
epoch 164/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 422.17it/s]
loss: 0.311127 perplexity: 1.364962 time: 4.79
epoch 165/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 417.23it/s]
loss: 0.309962 perplexity: 1.363373 time: 4.84
epoch 166/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 421.93it/s]
loss: 0.309264 perplexity: 1.362422 time: 4.79
epoch 167/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.87it/s]
loss: 0.309826 perplexity: 1.363188 time: 4.85
epoch 168/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 421.55it/s]
loss: 0.310204 perplexity: 1.363703 time: 4.79
epoch 169/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.22it/s]
loss: 0.311111 perplexity: 1.364941 time: 4.85
epoch 170/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 420.31it/s]
loss: 0.310690 perplexity: 1.364366 time: 4.81
epoch 171/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 415.69it/s]
loss: 0.310723 perplexity: 1.364412 time: 4.86
epoch 172/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 420.58it/s]
loss: 0.310778 perplexity: 1.364487 time: 4.80
epoch 173/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 415.61it/s]
loss: 0.310555 perplexity: 1.364182 time: 4.86
epoch 174/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 420.41it/s]
loss: 0.310340 perplexity: 1.363888 time: 4.81
epoch 175/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 415.37it/s]
loss: 0.309885 perplexity: 1.363268 time: 4.86
epoch 176/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 420.47it/s]
loss: 0.307662 perplexity: 1.360241 time: 4.80
epoch 177/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 415.83it/s]
loss: 0.309754 perplexity: 1.363090 time: 4.86
epoch 178/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 421.29it/s]
loss: 0.309925 perplexity: 1.363323 time: 4.80
epoch 179/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.10it/s]
loss: 0.308505 perplexity: 1.361389 time: 4.85
epoch 180/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 420.56it/s]
loss: 0.309475 perplexity: 1.362710 time: 4.80
epoch 181/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 414.84it/s]
loss: 0.311049 perplexity: 1.364856 time: 4.87
epoch 182/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 419.95it/s]
loss: 0.310470 perplexity: 1.364066 time: 4.81
epoch 183/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.51it/s]
loss: 0.310299 perplexity: 1.363833 time: 4.85
epoch 184/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 420.52it/s]
loss: 0.310061 perplexity: 1.363508 time: 4.80
epoch 185/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.14it/s]
loss: 0.312140 perplexity: 1.366345 time: 4.85
epoch 186/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 419.51it/s]
loss: 0.309345 perplexity: 1.362533 time: 4.82
epoch 187/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 415.50it/s]
loss: 0.310135 perplexity: 1.363610 time: 4.86
epoch 188/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 419.07it/s]
loss: 0.310015 perplexity: 1.363446 time: 4.82
epoch 189/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.75it/s]
loss: 0.310180 perplexity: 1.363671 time: 4.85
epoch 190/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 418.45it/s]
loss: 0.309288 perplexity: 1.362454 time: 4.83
epoch 191/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 414.58it/s]
loss: 0.311163 perplexity: 1.365012 time: 4.87
epoch 192/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 417.22it/s]
loss: 0.306858 perplexity: 1.359148 time: 4.84
epoch 193/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 414.43it/s]
loss: 0.307514 perplexity: 1.360040 time: 4.87
epoch 194/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.68it/s]
loss: 0.308191 perplexity: 1.360961 time: 4.85
epoch 195/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.08it/s]
loss: 0.308857 perplexity: 1.361867 time: 4.85
epoch 196/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 417.43it/s]
loss: 0.309727 perplexity: 1.363052 time: 4.84
epoch 197/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.41it/s]
loss: 0.309105 perplexity: 1.362206 time: 4.85
epoch 198/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 415.85it/s]
loss: 0.310670 perplexity: 1.364339 time: 4.86
epoch 199/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.28it/s]
loss: 0.308429 perplexity: 1.361285 time: 4.85
epoch 200/200: 100%|██████████████████████▉| 1999/2000 [00:04<00:00, 416.01it/s]
loss: 0.309753 perplexity: 1.363089 time: 4.86
未经允许不得转载:一亩三分地 » 基于 GRU + Attention 生成对联 – 上联模型