The lower-line (下联) model uses a Seq2Seq + Attention architecture. The computation proceeds as follows:
- First, the upper line (上联) is fed into the encoder to obtain a semantic representation of the upper line;
- Then, the lower line is fed into the decoder to obtain a semantic representation of the lower line;
- Next, for each decoder time step, an attention tensor over all encoder time steps is computed and concatenated to that decoder output;
- Finally, the attention-augmented tensors are fed into a linear layer to produce the predictions.
In this structure the encoder and decoder use exactly the same network architecture; that is, a single model class is used to instantiate both the encoder and the decoder, which also makes training much faster.
The original idea was to run the attention computation before feeding the input into the decoder at each step, but that turned out to slow training down noticeably. After switching to feeding the decoder first and computing attention afterwards, training became roughly 1.5× faster, and the model's results are still quite good.
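To make these steps concrete, here is a small, illustrative shape walk-through of the forward data flow. The hidden size (256), the example lengths (a 7-character upper line, an 8-token decoder input) and the vocabulary size (1098) are taken from the shape comments in the code later in this post; the tensors here are placeholders, not real model outputs.

import torch
import torch.nn as nn

# Shape walk-through of the four steps above (placeholder tensors, not the real model)
hidden = 256
vocab_size = 1098                                    # example value from the shape comments below
encoder_outputs = torch.zeros(1, 7, hidden)          # step 1: one encoder vector per upper-line character
decoder_outputs = torch.zeros(1, 8, hidden)          # step 2: one decoder vector per time step
attention_tensor = torch.zeros(1, 8, hidden)         # step 3: one attention summary per decoder step
combined = torch.cat([decoder_outputs, attention_tensor], dim=-1)   # (1, 8, 512)
logits = nn.Linear(2 * hidden, vocab_size)(combined)                # step 4: (1, 8, 1098)
print(logits.shape)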
import numpy as np
import torch
import torch.nn as nn
import pickle
from torch.utils.data import DataLoader
import torch.optim as optim
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import math
import pandas as pd
import heapq

# Define the compute device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
1. Vocabulary class
class DoubleVocab:

    def __init__(self):
        # Load the pre-built vocabulary
        vocab_data = pickle.load(open('data/vocab-double.pkl', 'rb'))
        self.word_to_index = vocab_data['word_to_index']
        self.index_to_word = vocab_data['index_to_word']
        self.vocab_size = vocab_data['vocab_size']
        self.EOS = self.word_to_index['[EOS]']
        self.SOS = self.word_to_index['[SOS]']

    def decode(self, input_ids):
        # Map a list of token ids back to a string
        result = []
        for input_id in input_ids:
            word = self.index_to_word[input_id]
            result.append(word)
        return ''.join(result)


vocab_data = DoubleVocab()
2. Dataset class
class DoubleDataset:

    def __init__(self):
        train_double_path = 'data/train-double.pkl'
        train_data = pickle.load(open(train_double_path, 'rb'))
        self.train_data = train_data['train_data']
        self.train_size = train_data['train_size']

    def __len__(self):
        return self.train_size

    def __getitem__(self, index):
        index = min(max(0, index), self.train_size - 1)
        encoder_inputs, decoder_data = self.train_data[index]
        # Shift the lower-line sequence by one position for teacher forcing
        decoder_inputs, decoder_true = decoder_data[:-1], decoder_data[1:]
        # Convert to tensors on the compute device
        encoder_inputs = torch.tensor(encoder_inputs, device=device)
        decoder_inputs = torch.tensor(decoder_inputs, device=device)
        decoder_true = torch.tensor(decoder_true, device=device)
        return encoder_inputs, decoder_inputs, decoder_true
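As a quick illustration of the decoder_inputs / decoder_true split above, here is a sketch assuming each stored lower-line sequence has the form [SOS], w1, …, wn, [EOS], which is what the slicing implies; the ids below are made up:

decoder_data = [1, 45, 12, 88, 2]      # hypothetical ids for [SOS], w1, w2, w3, [EOS]
decoder_inputs = decoder_data[:-1]     # [SOS], w1, w2, w3 -> what the decoder reads (teacher forcing)
decoder_true = decoder_data[1:]        # w1, w2, w3, [EOS] -> what it should predict at each step
print(decoder_inputs, decoder_true)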
3. Base model
As mentioned earlier, the encoder and decoder use exactly the same structure, so a BaseModel class is defined to build both of them and to extract the semantics of the upper and lower lines.
class BaseModel(nn.Module):

    def __init__(self):
        super(BaseModel, self).__init__()
        self.ebd = nn.Embedding(num_embeddings=vocab_data.vocab_size, embedding_dim=256)
        self.gru = nn.GRU(input_size=256, hidden_size=256, batch_first=True)

    def forward(self, inputs, hn):
        # Embed the upper-line or lower-line token ids
        inputs = self.ebd(inputs)
        # Feed the embedded sequence into the GRU layer
        outputs, hn = self.gru(inputs, hn)
        return outputs, hn

    def h0(self):
        # Initial hidden state: (num_layers, batch_size, hidden_size)
        return torch.zeros(1, 1, 256, device=device)
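A quick, illustrative shape check of BaseModel (reusing the imports, device and vocab_data defined above, with a made-up 7-character line):

encoder = BaseModel().to(device)
fake_line = torch.randint(0, vocab_data.vocab_size, (1, 7), device=device)  # a random 7-character "upper line"
outputs, hn = encoder(fake_line, encoder.h0())
print(outputs.shape, hn.shape)  # torch.Size([1, 7, 256]) torch.Size([1, 1, 256])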
4. Attention class
The idea behind the attention class is simple: the query is the decoder output at each time step, and the value is the encoder outputs at all time steps. What needs care here is how the tensor shapes change during the computation and how broadcasting is applied.
class Attention(nn.Module):

    def __init__(self):
        super(Attention, self).__init__()
        # Projects the decoder outputs into the attention space
        self.query_weight = nn.Linear(in_features=256, out_features=256)
        # Projects each encoder time-step output into the attention space
        self.value_weight = nn.Linear(in_features=256, out_features=256)
        # Produces one attention score per encoder time step for each query
        self.score_weight = nn.Linear(in_features=256, out_features=1)

    def forward(self, encoder_outputs, decoder_outputs):
        """
        :param encoder_outputs: encoder output at every time step
        :param decoder_outputs: decoder output at every time step
        :return: decoder outputs with the attention representation concatenated
        """
        # decoder_outputs has shape (1, 8, 256); after query_weight it is still (1, 8, 256)
        query = self.query_weight(decoder_outputs)
        # encoder_outputs has shape (1, 7, 256); after value_weight it is still (1, 7, 256)
        value = self.value_weight(encoder_outputs)

        # We want every 256-dim query vector to be added to every value vector,
        # so we first reshape value from (1, 7, 256) to (7, 1, 256); broadcasting then gives
        # (7, 1, 256) + (1, 8, 256) = (7, 8, 256)
        value = value.permute(1, 0, 2)

        # score_weight maps (7, 8, 256) to (7, 8, 1): for each of the 8 decoder time steps
        # we get 7 scores, one per encoder time step
        score = self.score_weight(torch.tanh(value + query))

        # Attention weights: normalize dimension 0 of the (7, 8, 1) tensor so that,
        # for each of the 8 decoder steps, the 7 scores sum to 1
        attention_weight = torch.softmax(score, dim=0)

        # Attention representation:
        # attention_weight has shape (7, 8, 1), encoder_outputs has shape (1, 7, 256).
        # First reshape encoder_outputs to (7, 1, 256), then bmm: (7, 8, 1) x (7, 1, 256) = (7, 8, 256).
        # Summing over dimension 0 adds up the 7 weighted encoder vectors for each decoder step,
        # giving (8, 256): one attention vector per decoder time step.
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        attention_tensor = torch.sum(torch.bmm(attention_weight, encoder_outputs), dim=0)

        # Concatenate the attention tensor with the decoder outputs:
        # attention_tensor is (8, 256) and decoder_outputs is (1, 8, 256), so we first
        # expand attention_tensor to (1, 8, 256) and then concatenate along the last dimension,
        # producing (1, 8, 512): each of the 8 decoder steps augmented with its attention tensor.
        decoder_inputs = torch.cat([decoder_outputs, attention_tensor.unsqueeze(0)], dim=-1)
        return decoder_inputs
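Since the shape bookkeeping is the trickiest part of this class, the following standalone snippet (random tensors, same 7-step encoder / 8-step decoder example) simply verifies the broadcasting and bmm shapes described in the comments:

import torch

value = torch.randn(7, 1, 256)                     # projected encoder outputs, reshaped
query = torch.randn(1, 8, 256)                     # projected decoder outputs
print((value + query).shape)                       # torch.Size([7, 8, 256]) via broadcasting

attention_weight = torch.softmax(torch.randn(7, 8, 1), dim=0)
encoder_outputs = torch.randn(7, 1, 256)
context = torch.sum(torch.bmm(attention_weight, encoder_outputs), dim=0)
print(context.shape)                               # torch.Size([8, 256])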
5. Model class
Besides the forward computation, this class also provides two prediction functions, predict and predict2. predict simply feeds the prediction from the previous time step back in as the input of the next time step, while predict2 uses beam search with k = 3 to improve the output; from a few test runs, the beam-search results are closer to the original training corpus. A toy illustration of one beam-search step is given after the class.
class DoubleModel(nn.Module):

    def __init__(self):
        super(DoubleModel, self).__init__()
        self.vocab_size = vocab_data.vocab_size
        self.encoder = BaseModel()
        self.decoder = BaseModel()
        self.attncal = Attention()
        self.outputs = nn.Linear(in_features=512, out_features=self.vocab_size)

    def forward(self, encoder_inputs, decoder_inputs):
        # 1. Feed the upper line into the encoder to extract its semantics
        encoder_h0 = self.encoder.h0()
        encoder_outputs, encoder_hn = self.encoder(encoder_inputs, encoder_h0)
        # 2. Feed the lower line into the decoder to extract its semantics
        decoder_h0 = self.decoder.h0()
        decoder_outputs, decoder_hn = self.decoder(decoder_inputs, decoder_h0)
        # 3. Compute the attention tensor for every decoder hidden state
        #    decoder_outputs shape: torch.Size([1, 8, 256])
        #    encoder_outputs shape: torch.Size([1, 7, 256])
        inputs = self.attncal(encoder_outputs, decoder_outputs)
        # 4. Feed the attention-augmented time-step outputs into the linear layer
        #    output shape: (1, 8, 1098)
        result = self.outputs(inputs)
        return result

    def predict(self, single):
        with torch.no_grad():
            # 1. Convert the input characters to indexes
            encoder_inputs = []
            for word in single:
                encoder_inputs.append(vocab_data.word_to_index[word])
            encoder_inputs = torch.tensor([encoder_inputs], device=device)
            # 2. Feed the upper line into the encoder to extract its semantics
            encoder_hn = self.encoder.h0()
            encoder_outputs, hn = self.encoder(encoder_inputs, encoder_hn)
            # 3. Initialize the decoder input with [SOS]
            outputs = [vocab_data.SOS]
            for _ in range(encoder_outputs.shape[1]):
                # 3.1 Build the decoder input tensor from the last predicted token, shape (1, 1)
                decoder_inputs = torch.tensor([[outputs[-1]]], device=device)
                # 3.2 Feed it into the decoder to get this time step's output, shape (1, 1, 256)
                decoder_outputs, hn = self.decoder(decoder_inputs, hn)
                # 3.3 Compute the attention representation between this output and all encoder outputs
                decoder_inputs = self.attncal(encoder_outputs, decoder_outputs)
                # 3.4 Feed the attention-augmented tensor into the linear layer to get the prediction
                result = self.outputs(decoder_inputs)
                outputs.append(torch.argmax(result.squeeze(), dim=0).item())
            # 4. Return the upper line and the predicted lower line
            return single, vocab_data.decode(outputs)

    def predict2(self, single):
        with torch.no_grad():
            # 1. Convert the input characters to indexes
            encoder_inputs = []
            for word in single:
                encoder_inputs.append(vocab_data.word_to_index[word])
            encoder_inputs = torch.tensor([encoder_inputs], device=device)
            # 2. Feed the upper line into the encoder to extract its semantics
            encoder_hn = self.encoder.h0()
            encoder_outputs, hn = self.encoder(encoder_inputs, encoder_hn)
            # 3. Beam search: each candidate is [probability, finished flag, token sequence, hidden state]
            container = []
            heapq.heappush(container, [1, False, [vocab_data.SOS], hn])
            while True:
                new_container = []
                for (proba, is_end, seq_list, hn) in container:
                    # Finished candidates are carried over unchanged
                    if is_end:
                        heapq.heappush(new_container, [proba, is_end, seq_list, hn])
                        continue
                    decoder_inputs = torch.tensor([[seq_list[-1]]], device=device)
                    decoder_outputs, hn = self.decoder(decoder_inputs, hn)
                    decoder_inputs = self.attncal(encoder_outputs, decoder_outputs)
                    result = self.outputs(decoder_inputs)
                    result = torch.softmax(result.squeeze(), dim=0)
                    # Take the top k = 3 predictions
                    values, indexes = torch.topk(result, 3)
                    for index, value in zip(indexes, values):
                        flag = True if index.item() == vocab_data.EOS else False
                        nseq = []
                        nseq.extend(seq_list)
                        nseq.append(index.item())
                        nproba = value.item() * proba
                        # Keep only the most probable candidates in the min-heap
                        if len(new_container) >= 3:
                            heapq.heappushpop(new_container, [nproba, flag, nseq, hn])
                        else:
                            heapq.heappush(new_container, [nproba, flag, nseq, hn])
                # Pick the most probable candidate; stop when it is long enough or has emitted [EOS]
                proba, is_end, seq_list, hn = max(new_container)
                if len(seq_list) == (len(single) + 1) or is_end == True:
                    return single, vocab_data.decode(seq_list)
                container = new_container
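To make the beam-search idea in predict2 easier to follow, here is a minimal toy example of a single expansion step; the tokens and probabilities are purely illustrative and have nothing to do with the trained model:

import heapq

k = 3
# Two hypothetical partial sequences (beams) and made-up next-token distributions
beams = [(0.6, ['SOS', 'A']), (0.4, ['SOS', 'B'])]
next_probs = {'A': {'C': 0.5, 'D': 0.3, 'E': 0.2},
              'B': {'C': 0.7, 'D': 0.2, 'E': 0.1}}

candidates = []
for proba, seq in beams:
    for token, p in next_probs[seq[-1]].items():
        item = (proba * p, seq + [token])
        if len(candidates) >= k:
            heapq.heappushpop(candidates, item)   # keep only the k most probable candidates
        else:
            heapq.heappush(candidates, item)

print(sorted(candidates, reverse=True))
# roughly: [(0.30, ['SOS', 'A', 'C']), (0.28, ['SOS', 'B', 'C']), (0.18, ['SOS', 'A', 'D'])]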
6. Training function
def train():
    # Load the data
    train_data = DoubleDataset()
    dataloader = DataLoader(train_data, shuffle=True, batch_size=1)
    # Build the model
    model = DoubleModel().to(device)
    # Loss function
    criterion = nn.CrossEntropyLoss()
    # Optimizer
    lr = 3e-4
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    # Number of training epochs
    epochs = 200
    # Loss history
    loss_curve = []
    for epoch_idx in range(1, epochs + 1):
        total_loss = 0.0
        total_iter = 0
        start_time = time.time()
        progress_bar = tqdm(range(train_data.train_size), desc='epoch: %d/%d' % (epoch_idx, epochs))
        for (encoder_inputs, decoder_inputs, decoder_true), _ in zip(dataloader, progress_bar):
            # Forward pass
            y_pred = model(encoder_inputs, decoder_inputs)
            # Compute the loss
            loss = criterion(y_pred.squeeze(), decoder_true.squeeze())
            # Reset gradients
            optimizer.zero_grad()
            # Backward pass
            loss.backward()
            # Update parameters
            optimizer.step()
            total_loss += loss.item() * len(decoder_inputs.squeeze())
            total_iter += len(decoder_inputs.squeeze())
        end_time = time.time()
        # Short pause so the progress bar output and the print below do not interleave
        time.sleep(0.1)
        loss = total_loss / total_iter
        loss_curve.append(loss)
        print('loss: %.4f perplexity: %.4f time: %.2f' % (loss, math.exp(loss), end_time - start_time))

    # Save the training artifacts
    model_save_info = {
        'lr': lr,
        'epoch': epochs,
        'optimizer': optimizer.state_dict(),
        'loss': loss,
        'model': model.state_dict()
    }
    model_loss_info = {
        'epoch': epochs,
        'loss': loss_curve
    }
    torch.save(model_save_info, 'model/double-model.pth')
    torch.save(model_loss_info, 'model/double-model-loss.pth')


def loss_curve():
    model_loss = torch.load('model/double-model-loss.pth')
    plt.plot(range(model_loss['epoch']), model_loss['loss'])
    plt.title('Double Loss Curve')
    plt.grid()
    plt.show()
7. Prediction function
def predict(single):
    # Load the saved model data
    model_data = torch.load('model/double-model.pth', map_location=device)
    # Initialize the prediction model
    model = DoubleModel().to(device)
    model.load_state_dict(model_data['model'])
    # Run inference
    # single, double = model.predict(single)
    single, double = model.predict2(single)
    return single, double
8. Training results
The model was trained for a total of 200 epochs, each taking roughly 7 seconds. The final loss is as follows:
loss: 0.0070 perplexity: 1.0070 time: 7.91
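The perplexity printed by the training loop is simply the exponential of the average per-token cross-entropy loss:

import math
print(math.exp(0.0070))  # ≈ 1.0070, matching the reported perplexity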
The change in the loss over training is shown in the figure below (plotted by loss_curve()):
Using an upper line as input, we predict the corresponding lower line:
if __name__ == '__main__':
    # train()
    # loss_curve()
    single, double = predict('晚风摇树树还挺')
    print(single, double)
    single, double = predict('愿景天成无墨迹')
    print(single, double)
    single, double = predict('不与时人争席地')
    print(single, double)
    single, double = predict('常赋江山大美')
    print(single, double)
    single, double = predict('张帆鼓浪行千里')
    print(single, double)
The prediction results are as follows:
晚风摇树树还挺 [SOS]晨露润花花更红
愿景天成无墨迹 [SOS]万方乐奏有于阗
不与时人争席地 [SOS]只将长剑笑昆仑
常赋江山大美 [SOS]每予岁月更新
张帆鼓浪行千里 [SOS]过海飘洋胜八仙
The results look reasonably good. The full training output is as follows:
epoch: 1/200: 100%|███████████████████████▉| 1999/2000 [00:09<00:00, 209.34it/s] loss: 6.4622 perplexity: 640.4755 time: 9.55 epoch: 2/200: 100%|███████████████████████▉| 1999/2000 [00:09<00:00, 209.34it/s] loss: 5.4442 perplexity: 231.4176 time: 9.55 epoch: 3/200: 100%|███████████████████████▉| 1999/2000 [00:09<00:00, 209.14it/s] loss: 4.4315 perplexity: 84.0603 time: 9.56 epoch: 4/200: 100%|███████████████████████▉| 1999/2000 [00:09<00:00, 208.47it/s] loss: 3.3452 perplexity: 28.3673 time: 9.59 epoch: 5/200: 100%|███████████████████████▉| 1999/2000 [00:09<00:00, 208.61it/s] loss: 2.3840 perplexity: 10.8484 time: 9.58 epoch: 6/200: 100%|███████████████████████▉| 1999/2000 [00:09<00:00, 208.78it/s] loss: 1.6355 perplexity: 5.1318 time: 9.57 epoch: 7/200: 100%|███████████████████████▉| 1999/2000 [00:09<00:00, 208.67it/s] loss: 1.0949 perplexity: 2.9890 time: 9.58 epoch: 8/200: 100%|███████████████████████▉| 1999/2000 [00:09<00:00, 208.57it/s] loss: 0.7263 perplexity: 2.0674 time: 9.58 epoch: 9/200: 100%|███████████████████████▉| 1999/2000 [00:09<00:00, 208.57it/s] loss: 0.4925 perplexity: 1.6364 time: 9.58 epoch: 10/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 208.60it/s] loss: 0.3400 perplexity: 1.4049 time: 9.58 epoch: 11/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 209.06it/s] loss: 0.2390 perplexity: 1.2700 time: 9.56 epoch: 12/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 208.96it/s] loss: 0.1816 perplexity: 1.1992 time: 9.57 epoch: 13/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 208.72it/s] loss: 0.1403 perplexity: 1.1506 time: 9.58 epoch: 14/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 208.36it/s] loss: 0.1124 perplexity: 1.1190 time: 9.59 epoch: 15/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 208.29it/s] loss: 0.0874 perplexity: 1.0914 time: 9.60 epoch: 16/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 208.21it/s] loss: 0.0742 perplexity: 1.0770 time: 9.60 epoch: 17/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 208.16it/s] loss: 0.0614 perplexity: 1.0633 time: 9.60 epoch: 18/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 208.46it/s] loss: 0.0550 perplexity: 1.0565 time: 9.59 epoch: 19/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 208.44it/s] loss: 0.0542 perplexity: 1.0557 time: 9.59 epoch: 20/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 207.90it/s] loss: 0.0408 perplexity: 1.0417 time: 9.62 epoch: 21/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 208.70it/s] loss: 0.0404 perplexity: 1.0412 time: 9.58 epoch: 22/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 208.50it/s] loss: 0.0400 perplexity: 1.0408 time: 9.59 epoch: 23/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 208.27it/s] loss: 0.0338 perplexity: 1.0344 time: 9.60 epoch: 24/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 208.55it/s] loss: 0.0316 perplexity: 1.0321 time: 9.59 epoch: 25/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 207.76it/s] loss: 0.0313 perplexity: 1.0318 time: 9.62 epoch: 26/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 207.67it/s] loss: 0.0271 perplexity: 1.0275 time: 9.63 epoch: 27/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 208.29it/s] loss: 0.0236 perplexity: 1.0239 time: 9.60 epoch: 28/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 207.97it/s] loss: 0.0217 perplexity: 1.0219 time: 9.61 epoch: 29/200: 100%|██████████████████████▉| 1999/2000 [00:09<00:00, 
207.92it/s] loss: 0.0175 perplexity: 1.0177 time: 9.61 epoch: 30/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 250.45it/s] loss: 0.0210 perplexity: 1.0212 time: 7.98 epoch: 31/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 250.98it/s] loss: 0.0157 perplexity: 1.0159 time: 7.97 epoch: 32/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 251.11it/s] loss: 0.0183 perplexity: 1.0185 time: 7.96 epoch: 33/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 251.18it/s] loss: 0.0179 perplexity: 1.0180 time: 7.96 epoch: 34/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.09it/s] loss: 0.0111 perplexity: 1.0112 time: 7.93 epoch: 35/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.53it/s] loss: 0.0100 perplexity: 1.0101 time: 7.92 epoch: 36/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.91it/s] loss: 0.0084 perplexity: 1.0085 time: 7.90 epoch: 37/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.88it/s] loss: 0.0170 perplexity: 1.0171 time: 7.91 epoch: 38/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.09it/s] loss: 0.0137 perplexity: 1.0138 time: 7.90 epoch: 39/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.81it/s] loss: 0.0097 perplexity: 1.0098 time: 7.91 epoch: 40/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.25it/s] loss: 0.0082 perplexity: 1.0082 time: 7.89 epoch: 41/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.18it/s] loss: 0.0088 perplexity: 1.0088 time: 7.90 epoch: 42/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.30it/s] loss: 0.0102 perplexity: 1.0103 time: 7.89 epoch: 43/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.57it/s] loss: 0.0112 perplexity: 1.0113 time: 7.92 epoch: 44/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.51it/s] loss: 0.0141 perplexity: 1.0142 time: 7.89 epoch: 45/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.87it/s] loss: 0.0067 perplexity: 1.0067 time: 7.87 epoch: 46/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.84it/s] loss: 0.0055 perplexity: 1.0055 time: 7.91 epoch: 47/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.94it/s] loss: 0.0097 perplexity: 1.0097 time: 7.90 epoch: 48/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.93it/s] loss: 0.0075 perplexity: 1.0076 time: 7.90 epoch: 49/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.97it/s] loss: 0.0074 perplexity: 1.0074 time: 7.90 epoch: 50/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.02it/s] loss: 0.0094 perplexity: 1.0094 time: 7.90 epoch: 51/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.23it/s] loss: 0.0078 perplexity: 1.0078 time: 7.89 epoch: 52/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.11it/s] loss: 0.0058 perplexity: 1.0058 time: 7.90 epoch: 53/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.12it/s] loss: 0.0068 perplexity: 1.0068 time: 7.90 epoch: 54/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.75it/s] loss: 0.0085 perplexity: 1.0086 time: 7.91 epoch: 55/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.58it/s] loss: 0.0123 perplexity: 1.0124 time: 7.91 epoch: 56/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.00it/s] loss: 0.0085 perplexity: 1.0085 time: 7.90 epoch: 57/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.10it/s] loss: 0.0051 perplexity: 1.0051 time: 7.90 epoch: 58/200: 
100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.70it/s] loss: 0.0074 perplexity: 1.0074 time: 7.88 epoch: 59/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.71it/s] loss: 0.0079 perplexity: 1.0079 time: 7.91 epoch: 60/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.53it/s] loss: 0.0072 perplexity: 1.0073 time: 7.88 epoch: 61/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.85it/s] loss: 0.0023 perplexity: 1.0023 time: 7.91 epoch: 62/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.75it/s] loss: 0.0046 perplexity: 1.0046 time: 7.91 epoch: 63/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.02it/s] loss: 0.0023 perplexity: 1.0023 time: 7.90 epoch: 64/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.08it/s] loss: 0.0039 perplexity: 1.0039 time: 7.90 epoch: 65/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.77it/s] loss: 0.0031 perplexity: 1.0031 time: 7.91 epoch: 66/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.67it/s] loss: 0.0060 perplexity: 1.0060 time: 7.91 epoch: 67/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.48it/s] loss: 0.0063 perplexity: 1.0063 time: 7.89 epoch: 68/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.82it/s] loss: 0.0054 perplexity: 1.0054 time: 7.91 epoch: 69/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.93it/s] loss: 0.0072 perplexity: 1.0073 time: 7.90 epoch: 70/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.69it/s] loss: 0.0022 perplexity: 1.0022 time: 7.91 epoch: 71/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.18it/s] loss: 0.0012 perplexity: 1.0012 time: 7.90 epoch: 72/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.04it/s] loss: 0.0058 perplexity: 1.0058 time: 7.90 epoch: 73/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.32it/s] loss: 0.0088 perplexity: 1.0088 time: 7.89 epoch: 74/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.82it/s] loss: 0.0073 perplexity: 1.0074 time: 7.91 epoch: 75/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 252.82it/s] loss: 0.0052 perplexity: 1.0053 time: 7.91 epoch: 76/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.29it/s] loss: 0.0045 perplexity: 1.0046 time: 7.89 epoch: 77/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.74it/s] loss: 0.0062 perplexity: 1.0062 time: 7.88 epoch: 78/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.91it/s] loss: 0.0062 perplexity: 1.0062 time: 7.87 epoch: 79/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.86it/s] loss: 0.0074 perplexity: 1.0074 time: 7.87 epoch: 80/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 253.72it/s] loss: 0.0057 perplexity: 1.0057 time: 7.88 epoch: 81/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.00it/s] loss: 0.0033 perplexity: 1.0033 time: 7.87 epoch: 82/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.74it/s] loss: 0.0046 perplexity: 1.0046 time: 7.85 epoch: 83/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.20it/s] loss: 0.0045 perplexity: 1.0045 time: 7.86 epoch: 84/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.94it/s] loss: 0.0091 perplexity: 1.0091 time: 7.84 epoch: 85/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.35it/s] loss: 0.0032 perplexity: 1.0032 time: 7.86 epoch: 86/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.51it/s] loss: 0.0026 
perplexity: 1.0026 time: 7.85 epoch: 87/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 255.37it/s] loss: 0.0068 perplexity: 1.0069 time: 7.83 epoch: 88/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.03it/s] loss: 0.0071 perplexity: 1.0071 time: 7.87 epoch: 89/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.78it/s] loss: 0.0072 perplexity: 1.0073 time: 7.85 epoch: 90/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.41it/s] loss: 0.0042 perplexity: 1.0042 time: 7.86 epoch: 91/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.20it/s] loss: 0.0045 perplexity: 1.0045 time: 7.86 epoch: 92/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.45it/s] loss: 0.0064 perplexity: 1.0064 time: 7.86 epoch: 93/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.39it/s] loss: 0.0071 perplexity: 1.0072 time: 7.86 epoch: 94/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.24it/s] loss: 0.0033 perplexity: 1.0033 time: 7.86 epoch: 95/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.83it/s] loss: 0.0056 perplexity: 1.0056 time: 7.84 epoch: 96/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.29it/s] loss: 0.0045 perplexity: 1.0045 time: 7.86 epoch: 97/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.27it/s] loss: 0.0070 perplexity: 1.0070 time: 7.86 epoch: 98/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.49it/s] loss: 0.0063 perplexity: 1.0063 time: 7.86 epoch: 99/200: 100%|██████████████████████▉| 1999/2000 [00:07<00:00, 254.47it/s] loss: 0.0078 perplexity: 1.0079 time: 7.86 epoch: 100/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.43it/s] loss: 0.0058 perplexity: 1.0058 time: 7.86 epoch: 101/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.47it/s] loss: 0.0032 perplexity: 1.0032 time: 7.86 epoch: 102/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.51it/s] loss: 0.0026 perplexity: 1.0026 time: 7.85 epoch: 103/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 253.71it/s] loss: 0.0023 perplexity: 1.0023 time: 7.88 epoch: 104/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.30it/s] loss: 0.0050 perplexity: 1.0051 time: 7.86 epoch: 105/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.77it/s] loss: 0.0034 perplexity: 1.0034 time: 7.85 epoch: 106/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.46it/s] loss: 0.0060 perplexity: 1.0060 time: 7.86 epoch: 107/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.86it/s] loss: 0.0055 perplexity: 1.0055 time: 7.84 epoch: 108/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.26it/s] loss: 0.0037 perplexity: 1.0037 time: 7.86 epoch: 109/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.88it/s] loss: 0.0036 perplexity: 1.0036 time: 7.84 epoch: 110/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.20it/s] loss: 0.0060 perplexity: 1.0060 time: 7.86 epoch: 111/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.57it/s] loss: 0.0069 perplexity: 1.0069 time: 7.85 epoch: 112/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.06it/s] loss: 0.0050 perplexity: 1.0051 time: 7.87 epoch: 113/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.65it/s] loss: 0.0052 perplexity: 1.0052 time: 7.85 epoch: 114/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 253.84it/s] loss: 0.0048 perplexity: 1.0048 time: 7.88 epoch: 115/200: 100%|█████████████████████▉| 
1999/2000 [00:07<00:00, 253.92it/s] loss: 0.0018 perplexity: 1.0018 time: 7.87 epoch: 116/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 255.27it/s] loss: 0.0028 perplexity: 1.0028 time: 7.83 epoch: 117/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.87it/s] loss: 0.0053 perplexity: 1.0053 time: 7.84 epoch: 118/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.70it/s] loss: 0.0091 perplexity: 1.0091 time: 7.85 epoch: 119/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.37it/s] loss: 0.0084 perplexity: 1.0085 time: 7.86 epoch: 120/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.54it/s] loss: 0.0036 perplexity: 1.0036 time: 7.85 epoch: 121/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.88it/s] loss: 0.0020 perplexity: 1.0020 time: 7.84 epoch: 122/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.53it/s] loss: 0.0032 perplexity: 1.0032 time: 7.85 epoch: 123/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.54it/s] loss: 0.0046 perplexity: 1.0046 time: 7.85 epoch: 124/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.87it/s] loss: 0.0052 perplexity: 1.0052 time: 7.84 epoch: 125/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 253.94it/s] loss: 0.0035 perplexity: 1.0035 time: 7.87 epoch: 126/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.30it/s] loss: 0.0020 perplexity: 1.0020 time: 7.86 epoch: 127/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.46it/s] loss: 0.0008 perplexity: 1.0008 time: 7.86 epoch: 128/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.16it/s] loss: 0.0008 perplexity: 1.0008 time: 7.87 epoch: 129/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.23it/s] loss: 0.0100 perplexity: 1.0100 time: 7.86 epoch: 130/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.53it/s] loss: 0.0016 perplexity: 1.0016 time: 7.85 epoch: 131/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.04it/s] loss: 0.0046 perplexity: 1.0046 time: 7.87 epoch: 132/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 253.61it/s] loss: 0.0045 perplexity: 1.0045 time: 7.88 epoch: 133/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 254.14it/s] loss: 0.0033 perplexity: 1.0033 time: 7.87 epoch: 134/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 252.24it/s] loss: 0.0051 perplexity: 1.0051 time: 7.93 epoch: 135/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.64it/s] loss: 0.0075 perplexity: 1.0075 time: 7.94 epoch: 136/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 252.30it/s] loss: 0.0029 perplexity: 1.0029 time: 7.92 epoch: 137/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 252.53it/s] loss: 0.0063 perplexity: 1.0063 time: 7.92 epoch: 138/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 252.68it/s] loss: 0.0021 perplexity: 1.0021 time: 7.91 epoch: 139/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 252.86it/s] loss: 0.0032 perplexity: 1.0032 time: 7.91 epoch: 140/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 252.73it/s] loss: 0.0016 perplexity: 1.0016 time: 7.91 epoch: 141/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 252.61it/s] loss: 0.0073 perplexity: 1.0074 time: 7.91 epoch: 142/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 252.73it/s] loss: 0.0060 perplexity: 1.0060 time: 7.91 epoch: 143/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 252.68it/s] loss: 0.0055 perplexity: 1.0055 time: 7.91 
epoch: 144/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 252.28it/s] loss: 0.0047 perplexity: 1.0047 time: 7.92 epoch: 145/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 253.38it/s] loss: 0.0034 perplexity: 1.0034 time: 7.89 epoch: 146/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 253.21it/s] loss: 0.0045 perplexity: 1.0045 time: 7.90 epoch: 147/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 253.16it/s] loss: 0.0012 perplexity: 1.0012 time: 7.90 epoch: 148/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 252.54it/s] loss: 0.0025 perplexity: 1.0025 time: 7.92 epoch: 149/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 253.00it/s] loss: 0.0096 perplexity: 1.0096 time: 7.90 epoch: 150/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 253.11it/s] loss: 0.0038 perplexity: 1.0038 time: 7.90 epoch: 151/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 252.43it/s] loss: 0.0020 perplexity: 1.0020 time: 7.92 epoch: 152/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 250.86it/s] loss: 0.0017 perplexity: 1.0017 time: 7.97 epoch: 153/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.76it/s] loss: 0.0012 perplexity: 1.0012 time: 7.94 epoch: 154/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 250.49it/s] loss: 0.0038 perplexity: 1.0038 time: 7.98 epoch: 155/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.16it/s] loss: 0.0036 perplexity: 1.0036 time: 7.96 epoch: 156/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.37it/s] loss: 0.0089 perplexity: 1.0089 time: 7.95 epoch: 157/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.27it/s] loss: 0.0028 perplexity: 1.0028 time: 7.96 epoch: 158/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.42it/s] loss: 0.0022 perplexity: 1.0022 time: 7.95 epoch: 159/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 250.53it/s] loss: 0.0079 perplexity: 1.0080 time: 7.98 epoch: 160/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 250.70it/s] loss: 0.0052 perplexity: 1.0052 time: 7.97 epoch: 161/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 250.84it/s] loss: 0.0039 perplexity: 1.0039 time: 7.97 epoch: 162/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.16it/s] loss: 0.0022 perplexity: 1.0022 time: 7.96 epoch: 163/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.15it/s] loss: 0.0024 perplexity: 1.0024 time: 7.96 epoch: 164/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.18it/s] loss: 0.0044 perplexity: 1.0044 time: 7.96 epoch: 165/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.06it/s] loss: 0.0055 perplexity: 1.0055 time: 7.96 epoch: 166/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 250.91it/s] loss: 0.0050 perplexity: 1.0050 time: 7.97 epoch: 167/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.56it/s] loss: 0.0049 perplexity: 1.0049 time: 7.95 epoch: 168/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.01it/s] loss: 0.0032 perplexity: 1.0032 time: 7.96 epoch: 169/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.17it/s] loss: 0.0018 perplexity: 1.0018 time: 7.96 epoch: 170/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.03it/s] loss: 0.0055 perplexity: 1.0055 time: 7.96 epoch: 171/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.78it/s] loss: 0.0065 perplexity: 1.0065 time: 7.94 epoch: 172/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.03it/s] 
loss: 0.0023 perplexity: 1.0023 time: 7.96 epoch: 173/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.70it/s] loss: 0.0035 perplexity: 1.0035 time: 7.94 epoch: 174/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.40it/s] loss: 0.0099 perplexity: 1.0099 time: 7.95 epoch: 175/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.24it/s] loss: 0.0042 perplexity: 1.0042 time: 7.96 epoch: 176/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 250.65it/s] loss: 0.0028 perplexity: 1.0028 time: 7.98 epoch: 177/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.46it/s] loss: 0.0036 perplexity: 1.0036 time: 7.95 epoch: 178/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.10it/s] loss: 0.0026 perplexity: 1.0026 time: 7.96 epoch: 179/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.08it/s] loss: 0.0013 perplexity: 1.0013 time: 7.96 epoch: 180/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.25it/s] loss: 0.0057 perplexity: 1.0057 time: 7.96 epoch: 181/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.13it/s] loss: 0.0048 perplexity: 1.0048 time: 7.96 epoch: 182/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.18it/s] loss: 0.0018 perplexity: 1.0018 time: 7.96 epoch: 183/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.06it/s] loss: 0.0012 perplexity: 1.0012 time: 7.96 epoch: 184/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.32it/s] loss: 0.0019 perplexity: 1.0019 time: 7.95 epoch: 185/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.09it/s] loss: 0.0093 perplexity: 1.0093 time: 7.96 epoch: 186/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.36it/s] loss: 0.0046 perplexity: 1.0046 time: 7.95 epoch: 187/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.35it/s] loss: 0.0044 perplexity: 1.0044 time: 7.95 epoch: 188/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.63it/s] loss: 0.0032 perplexity: 1.0032 time: 7.94 epoch: 189/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.14it/s] loss: 0.0030 perplexity: 1.0030 time: 7.96 epoch: 190/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.57it/s] loss: 0.0024 perplexity: 1.0024 time: 7.95 epoch: 191/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.22it/s] loss: 0.0032 perplexity: 1.0033 time: 7.96 epoch: 192/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.69it/s] loss: 0.0033 perplexity: 1.0033 time: 7.94 epoch: 193/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.44it/s] loss: 0.0035 perplexity: 1.0036 time: 7.95 epoch: 194/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.27it/s] loss: 0.0025 perplexity: 1.0025 time: 7.96 epoch: 195/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 251.54it/s] loss: 0.0029 perplexity: 1.0029 time: 7.95 epoch: 196/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 250.87it/s] loss: 0.0022 perplexity: 1.0022 time: 7.97 epoch: 197/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 253.08it/s] loss: 0.0028 perplexity: 1.0028 time: 7.90 epoch: 198/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 253.00it/s] loss: 0.0030 perplexity: 1.0030 time: 7.90 epoch: 199/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 252.82it/s] loss: 0.0054 perplexity: 1.0055 time: 7.91 epoch: 200/200: 100%|█████████████████████▉| 1999/2000 [00:07<00:00, 252.63it/s]