这部分主要包含两部分:
- 对联模型类编写
- 训练函数的实现
1. 对联模型
首先,我们将加载数据处理时得到的 tokenizer-encode-tokenizer, 主要用于获得词表大小,当然,我们在前面处理时,可以单独将词表大小存储起来,这里就不用加载整个词表对象了。
然后,使用 GPT 默认对象参数来初始化一个未训练过的模型对象。
最后, 将 GPT 模型的最后一个隐藏层送入线性层得到预测输出,这里需要注意的是,我们这里使用的 GPT-2 模型的词嵌入维度是 768。forward 前向计算函数,最终会返回 Linear 的结果,和 GPT 的 past_key_values 结果。
past_key_values 该值表示历史输入的自注意力计算中的 key value,有了 past_key_values 只需要新输入的值计算其张量表示即可,已输入的值无须重复计算。
完整代码如下:
class Model(nn.Module): def __init__(self): super(Model, self).__init__() tokenizer = BertTokenizer.from_pretrained('data/tokenizer-encode-tokenizer') config = GPT2Config() self.gpt = GPT2Model(config=config) self.out = nn.Linear(in_features=768, out_features=tokenizer.vocab_size) def forward(self, inputs, past_key_values=None): outputs = self.gpt(inputs, past_key_values=past_key_values) # 输入所有字的最后一个隐藏层的输出向量 model_output = outputs.last_hidden_state model_hidden = outputs.past_key_values return self.out(model_output), model_hidden
2. 训练函数
训练的一些主要参数如下:
- batch_size 为 16,我的 GPU 显存是 6G,模型加载到显存之后,剩余的显存空间不足,最大也就设置为 16,设置其他的值,会提示显存不足无法训练。你可以根据自己的显存大小,来调整此值。
- 学习率为 5e-5
- epoch 数量为 30
- 优化方法为 AdamW
- 损失函数为 CrossEntropyLoss,为了方便计算损失,我这里将其 reduction 设置为 sum,计算总损失。
每一个 epoch 训练结束就存储下模型,存储的格式为:couplet-gpt2-%d.bin,我的电脑训练一个 epoch 31万+ 数据需要 31 分 30 秒左右。
def train(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 初始化模型 model = Model().to(device) # 加载数据 train_data = load_from_disk('data/data-encode-corpus') # train_data.set_format('pytorch', device=device) # 损失函数 criterion = nn.CrossEntropyLoss(reduction='sum') # 优化方法 optimizer = optim.AdamW(model.parameters(), lr=5e-5) # 训练轮数 epochs = 30 # 训练信息 total_loss = 0.0 total_iter = 0 def start_train(inputs): # 模型输出 model_inputs = inputs['inputs'] model_inputs = torch.tensor(model_inputs, device=device) outputs, past_key_values = model(model_inputs) # 计算损失 loss_sum = 0.0 iter_num = 0 model_labels = inputs['labels'] for y_pred, y_true in zip(outputs, model_labels): # 截取预测有效部分 y_true = torch.tensor(y_true, device=device) y_pred = y_pred[:len(y_true)] loss = criterion(y_pred, y_true) loss_sum += loss iter_num += len(y_true) mean_loss = loss_sum / iter_num # 梯度清零 optimizer.zero_grad() # 反向传播 mean_loss.backward() # 参数更新 optimizer.step() nonlocal total_loss, total_iter total_loss += loss_sum total_iter += iter_num for epoch_idx in range(epochs): start = time.time() train_data.map(start_train, batched=True, batch_size=16) print('epoch: %d loss: %.5f time: %.2f' % (epoch_idx, total_loss / total_iter, time.time() - start)) total_loss = 0.0 total_iter = 0 torch.save(model.state_dict(), 'model/couplet-gpt2-%d.bin' % (epoch_idx + 1))
2. 训练结果
经过 30 个 epoch 训练之后,最后的损失为 0.99273,如果再增加训练轮数,损失还会继续下降,我们暂且训练到这里。
100%|█████████████████████████████████████| 19552/19552 [31:28<00:00, 10.35ba/s] epoch: 0 loss: 4.55596 time: 1889.14 100%|█████████████████████████████████████| 19552/19552 [31:33<00:00, 10.32ba/s] epoch: 1 loss: 3.90209 time: 1895.03 100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s] epoch: 2 loss: 3.63177 time: 1895.77 100%|█████████████████████████████████████| 19552/19552 [31:28<00:00, 10.35ba/s] epoch: 3 loss: 3.41664 time: 1889.86 100%|█████████████████████████████████████| 19552/19552 [31:27<00:00, 10.36ba/s] epoch: 4 loss: 3.21284 time: 1889.00 100%|█████████████████████████████████████| 19552/19552 [31:27<00:00, 10.36ba/s] epoch: 5 loss: 3.01134 time: 1889.62 100%|█████████████████████████████████████| 19552/19552 [31:28<00:00, 10.35ba/s] epoch: 6 loss: 2.81330 time: 1889.99 100%|█████████████████████████████████████| 19552/19552 [31:30<00:00, 10.34ba/s] epoch: 7 loss: 2.62306 time: 1892.14 100%|█████████████████████████████████████| 19552/19552 [31:31<00:00, 10.34ba/s] epoch: 8 loss: 2.44286 time: 1893.17 100%|█████████████████████████████████████| 19552/19552 [31:31<00:00, 10.34ba/s] epoch: 9 loss: 2.27614 time: 1893.16 100%|█████████████████████████████████████| 19552/19552 [31:31<00:00, 10.34ba/s] epoch: 10 loss: 2.12369 time: 1893.18 100%|█████████████████████████████████████| 19552/19552 [31:31<00:00, 10.34ba/s] epoch: 11 loss: 1.98639 time: 1893.53 100%|█████████████████████████████████████| 19552/19552 [31:31<00:00, 10.33ba/s] epoch: 12 loss: 1.86227 time: 1893.50 100%|█████████████████████████████████████| 19552/19552 [31:31<00:00, 10.33ba/s] epoch: 13 loss: 1.75238 time: 1893.22 100%|█████████████████████████████████████| 19552/19552 [31:32<00:00, 10.33ba/s] epoch: 14 loss: 1.65421 time: 1894.55 100%|█████████████████████████████████████| 19552/19552 [31:32<00:00, 10.33ba/s] epoch: 15 loss: 1.56802 time: 1894.66 100%|█████████████████████████████████████| 19552/19552 [31:32<00:00, 10.33ba/s] epoch: 16 loss: 1.49034 time: 1894.05 100%|█████████████████████████████████████| 19552/19552 [31:32<00:00, 10.33ba/s] epoch: 17 loss: 1.42125 time: 1894.59 100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s] epoch: 18 loss: 1.36117 time: 1896.22 100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s] epoch: 19 loss: 1.30739 time: 1895.68 100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s] epoch: 20 loss: 1.26003 time: 1896.38 100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s] epoch: 21 loss: 1.21647 time: 1896.49 100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s] epoch: 22 loss: 1.17920 time: 1896.52 100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s] epoch: 23 loss: 1.14415 time: 1896.62 100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s] epoch: 24 loss: 1.11351 time: 1896.89 100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s] epoch: 25 loss: 1.08483 time: 1896.75 100%|█████████████████████████████████████| 19552/19552 [31:35<00:00, 10.32ba/s] epoch: 26 loss: 1.05824 time: 1897.03 100%|█████████████████████████████████████| 19552/19552 [31:33<00:00, 10.32ba/s] epoch: 27 loss: 1.03451 time: 1895.75 100%|█████████████████████████████████████| 19552/19552 [31:32<00:00, 10.33ba/s] epoch: 28 loss: 1.01332 time: 1894.80 100%|█████████████████████████████████████| 19552/19552 [31:33<00:00, 10.32ba/s] epoch: 29 loss: 0.99273 time: 1895.03