Gradient Checkpoint

Gradient Checkpoint 是一种能够节省内存的技术。什么时候需要节省内存呢?比如:模型太大,无法放到西有限的显存中训练。或者模型能够放到显存中,但是只能使用较小的 batch size, 我们知道有时候使用较小的 batch size 可能会导致模型无法收敛。在这种情况下,我们就需要一种技术来降低训练过程中模型占用显存的大小,从而加大 batch size。

Gradient Checkpoint 就是一种节省训练过程中显存占用的技术,或者说它是一种用时间换空间的技术。需要注意的是 Gradient Checkpoint 可能会导致训练时间增加,并不能加快训练速度。

Gradient Checkpoint 是如何做到时间换空间的呢?

我们知道模型训练前向计算过程中,需要保存大量的用于反向梯度计算的中间结果,这部分的显存占了相当可观的比重。Gradient Checkpoint 就是通过只缓存部分中间结果的方式来减少训练过程中的显存占用。我们刚提到过,缓存中间结果的目的是为了 backward 时计算梯度。如果我们只缓存部分的话, 某些反向梯度计算时需要的中间结果,只能再进行一次正向计算来获得。从这点来看的话,会额外增加一些计算任务,从而导致模型训练时间变长,但是却得到了更多的显存资源。这也是我们说的以时间换空间。

如何在自己的模型中使用 Gradient Checkpoint 呢?

我们可以使用 torch.utils.checkpoint 来实现,下面给出一个自己写的简单示意代码:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import torch.cuda as cuda
import pynvml


pynvml.nvmlInit()
device_object = pynvml.nvmlDeviceGetHandleByIndex(0)

def show_usage():

    device_memory = pynvml.nvmlDeviceGetMemoryInfo(device_object)
    total = device_memory.total
    used  = device_memory.used
    free  = device_memory.free
    print('总共:', total, '使用:', used, '剩余:', free)


class Net(nn.Module):

    def __init__(self, gradient_ceckpoint=False):
        super(Net, self).__init__()
        self.gradient_ceckpoint = gradient_ceckpoint
        self.linear1 = nn.Linear(1024, 1024 * 10)
        self.block1 = nn.Sequential(*[nn.Linear(1024 * 10, 1024 * 10), nn.Tanh()])
        self.block2 = nn.Sequential(*[nn.Linear(1024 * 10, 1024 * 10), nn.Tanh()])
        self.block3 = nn.Sequential(*[nn.Linear(1024 * 10, 1024 * 10), nn.Tanh()])


    def forward(self, inputs):

        if self.gradient_ceckpoint:
            inputs = self.linear1(inputs)
            inputs = checkpoint(self.block1, inputs)
            inputs = checkpoint(self.block2, inputs)
            output = self.block3(inputs)
        else:
            inputs = self.linear1(inputs)
            inputs = self.block1(inputs)
            inputs = self.block2(inputs)
            output = self.block3(inputs)

        return output


if __name__ == '__main__':

    show_usage()
    print(cuda.memory_allocated())
    model = Net(gradient_ceckpoint=False).cuda()
    model.train()

    for _ in range(2):
        inputs = torch.randn(size=[512, 1024]).cuda()
        print(cuda.memory_allocated())
        output = model(inputs)
        print(cuda.memory_allocated())
        loss = torch.mean(output)
        print(cuda.memory_allocated())
        loss.backward()
        print(cuda.memory_allocated())
        print('-' * 50)

    show_usage()

参考:https://spell.ml/blog/gradient-checkpointing-pytorch-YGypLBAAACEAefHs

未经允许不得转载:一亩三分地 » Gradient Checkpoint
评论 (0)

4 + 3 =