基于 GPT 训练对联模型 – 模型训练

这部分主要包含两部分:

  1. 对联模型类编写
  2. 训练函数的实现

1. 对联模型

首先,我们将加载数据处理时得到的 tokenizer-encode-tokenizer, 主要用于获得词表大小,当然,我们在前面处理时,可以单独将词表大小存储起来,这里就不用加载整个词表对象了。

然后,使用 GPT 默认对象参数来初始化一个未训练过的模型对象。

最后, 将 GPT 模型的最后一个隐藏层送入线性层得到预测输出,这里需要注意的是,我们这里使用的 GPT-2 模型的词嵌入维度是 768。forward 前向计算函数,最终会返回 Linear 的结果,和 GPT 的 past_key_values 结果。

past_key_values 该值表示历史输入的自注意力计算中的 key value,有了 past_key_values 只需要新输入的值计算其张量表示即可,已输入的值无须重复计算。

完整代码如下:

class Model(nn.Module):
    
    def __init__(self):
        super(Model, self).__init__()
        tokenizer = BertTokenizer.from_pretrained('data/tokenizer-encode-tokenizer')
        config = GPT2Config()
        self.gpt = GPT2Model(config=config)
        self.out = nn.Linear(in_features=768, out_features=tokenizer.vocab_size)

    def forward(self, inputs, past_key_values=None):

        outputs = self.gpt(inputs, past_key_values=past_key_values)

        # 输入所有字的最后一个隐藏层的输出向量
        model_output = outputs.last_hidden_state
        model_hidden = outputs.past_key_values

        return self.out(model_output), model_hidden

2. 训练函数

训练的一些主要参数如下:

  1. batch_size 为 16,我的 GPU 显存是 6G,模型加载到显存之后,剩余的显存空间不足,最大也就设置为 16,设置其他的值,会提示显存不足无法训练。你可以根据自己的显存大小,来调整此值。
  2. 学习率为 5e-5
  3. epoch 数量为 30
  4. 优化方法为 AdamW
  5. 损失函数为 CrossEntropyLoss,为了方便计算损失,我这里将其 reduction 设置为 sum,计算总损失。

每一个 epoch 训练结束就存储下模型,存储的格式为:couplet-gpt2-%d.bin,我的电脑训练一个 epoch 31万+ 数据需要 31 分 30 秒左右。

def train():

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 初始化模型
    model = Model().to(device)
    # 加载数据
    train_data = load_from_disk('data/data-encode-corpus')
    # train_data.set_format('pytorch', device=device)

    # 损失函数
    criterion = nn.CrossEntropyLoss(reduction='sum')
    # 优化方法
    optimizer = optim.AdamW(model.parameters(), lr=5e-5)
    # 训练轮数
    epochs = 30

    # 训练信息
    total_loss = 0.0
    total_iter = 0

    def start_train(inputs):

        # 模型输出
        model_inputs = inputs['inputs']
        model_inputs = torch.tensor(model_inputs, device=device)
        outputs, past_key_values = model(model_inputs)

        # 计算损失
        loss_sum = 0.0
        iter_num = 0
        model_labels = inputs['labels']

        for y_pred, y_true in zip(outputs, model_labels):
            # 截取预测有效部分
            y_true = torch.tensor(y_true, device=device)
            y_pred = y_pred[:len(y_true)]
            loss = criterion(y_pred, y_true)
            loss_sum += loss
            iter_num += len(y_true)

        mean_loss = loss_sum / iter_num
        # 梯度清零
        optimizer.zero_grad()
        # 反向传播
        mean_loss.backward()
        # 参数更新
        optimizer.step()

        nonlocal total_loss, total_iter
        total_loss += loss_sum
        total_iter += iter_num


    for epoch_idx in range(epochs):

        start = time.time()
        train_data.map(start_train, batched=True, batch_size=16)
        print('epoch: %d loss: %.5f time: %.2f' % (epoch_idx, total_loss / total_iter, time.time() - start))
        total_loss = 0.0
        total_iter = 0

        torch.save(model.state_dict(), 'model/couplet-gpt2-%d.bin' % (epoch_idx + 1))

2. 训练结果

经过 30 个 epoch 训练之后,最后的损失为 0.99273,如果再增加训练轮数,损失还会继续下降,我们暂且训练到这里。

100%|█████████████████████████████████████| 19552/19552 [31:28<00:00, 10.35ba/s]
epoch: 0 loss: 4.55596 time: 1889.14
100%|█████████████████████████████████████| 19552/19552 [31:33<00:00, 10.32ba/s]
epoch: 1 loss: 3.90209 time: 1895.03
100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s]
epoch: 2 loss: 3.63177 time: 1895.77
100%|█████████████████████████████████████| 19552/19552 [31:28<00:00, 10.35ba/s]
epoch: 3 loss: 3.41664 time: 1889.86
100%|█████████████████████████████████████| 19552/19552 [31:27<00:00, 10.36ba/s]
epoch: 4 loss: 3.21284 time: 1889.00
100%|█████████████████████████████████████| 19552/19552 [31:27<00:00, 10.36ba/s]
epoch: 5 loss: 3.01134 time: 1889.62
100%|█████████████████████████████████████| 19552/19552 [31:28<00:00, 10.35ba/s]
epoch: 6 loss: 2.81330 time: 1889.99
100%|█████████████████████████████████████| 19552/19552 [31:30<00:00, 10.34ba/s]
epoch: 7 loss: 2.62306 time: 1892.14
100%|█████████████████████████████████████| 19552/19552 [31:31<00:00, 10.34ba/s]
epoch: 8 loss: 2.44286 time: 1893.17
100%|█████████████████████████████████████| 19552/19552 [31:31<00:00, 10.34ba/s]
epoch: 9 loss: 2.27614 time: 1893.16
100%|█████████████████████████████████████| 19552/19552 [31:31<00:00, 10.34ba/s]
epoch: 10 loss: 2.12369 time: 1893.18
100%|█████████████████████████████████████| 19552/19552 [31:31<00:00, 10.34ba/s]
epoch: 11 loss: 1.98639 time: 1893.53
100%|█████████████████████████████████████| 19552/19552 [31:31<00:00, 10.33ba/s]
epoch: 12 loss: 1.86227 time: 1893.50
100%|█████████████████████████████████████| 19552/19552 [31:31<00:00, 10.33ba/s]
epoch: 13 loss: 1.75238 time: 1893.22
100%|█████████████████████████████████████| 19552/19552 [31:32<00:00, 10.33ba/s]
epoch: 14 loss: 1.65421 time: 1894.55
100%|█████████████████████████████████████| 19552/19552 [31:32<00:00, 10.33ba/s]
epoch: 15 loss: 1.56802 time: 1894.66
100%|█████████████████████████████████████| 19552/19552 [31:32<00:00, 10.33ba/s]
epoch: 16 loss: 1.49034 time: 1894.05
100%|█████████████████████████████████████| 19552/19552 [31:32<00:00, 10.33ba/s]
epoch: 17 loss: 1.42125 time: 1894.59
100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s]
epoch: 18 loss: 1.36117 time: 1896.22
100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s]
epoch: 19 loss: 1.30739 time: 1895.68
100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s]
epoch: 20 loss: 1.26003 time: 1896.38
100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s]
epoch: 21 loss: 1.21647 time: 1896.49
100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s]
epoch: 22 loss: 1.17920 time: 1896.52
100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s]
epoch: 23 loss: 1.14415 time: 1896.62
100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s]
epoch: 24 loss: 1.11351 time: 1896.89
100%|█████████████████████████████████████| 19552/19552 [31:34<00:00, 10.32ba/s]
epoch: 25 loss: 1.08483 time: 1896.75
100%|█████████████████████████████████████| 19552/19552 [31:35<00:00, 10.32ba/s]
epoch: 26 loss: 1.05824 time: 1897.03
100%|█████████████████████████████████████| 19552/19552 [31:33<00:00, 10.32ba/s]
epoch: 27 loss: 1.03451 time: 1895.75
100%|█████████████████████████████████████| 19552/19552 [31:32<00:00, 10.33ba/s]
epoch: 28 loss: 1.01332 time: 1894.80
100%|█████████████████████████████████████| 19552/19552 [31:33<00:00, 10.32ba/s]
epoch: 29 loss: 0.99273 time: 1895.03

未经允许不得转载:一亩三分地 » 基于 GPT 训练对联模型 – 模型训练