Gradient Checkpoint 是一种能够节省内存的技术。什么时候需要节省内存呢?比如:模型太大,无法放到西有限的显存中训练。或者模型能够放到显存中,但是只能使用较小的 batch size, 我们知道有时候使用较小的 batch size 可能会导致模型无法收敛。在这种情况下,我们就需要一种技术来降低训练过程中模型占用显存的大小,从而加大 batch size。
Gradient Checkpoint 就是一种节省训练过程中显存占用的技术,或者说它是一种用时间换空间的技术。需要注意的是 Gradient Checkpoint 可能会导致训练时间增加,并不能加快训练速度。
Gradient Checkpoint 是如何做到时间换空间的呢?
我们知道模型训练前向计算过程中,需要保存大量的用于反向梯度计算的中间结果,这部分的显存占了相当可观的比重。Gradient Checkpoint 就是通过只缓存部分中间结果的方式来减少训练过程中的显存占用。我们刚提到过,缓存中间结果的目的是为了 backward 时计算梯度。如果我们只缓存部分的话, 某些反向梯度计算时需要的中间结果,只能再进行一次正向计算来获得。从这点来看的话,会额外增加一些计算任务,从而导致模型训练时间变长,但是却得到了更多的显存资源。这也是我们说的以时间换空间。
如何在自己的模型中使用 Gradient Checkpoint 呢?
我们可以使用 torch.utils.checkpoint
来实现,下面给出一个自己写的简单示意代码:
import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint import torch.cuda as cuda import pynvml pynvml.nvmlInit() device_object = pynvml.nvmlDeviceGetHandleByIndex(0) def show_usage(): device_memory = pynvml.nvmlDeviceGetMemoryInfo(device_object) total = device_memory.total used = device_memory.used free = device_memory.free print('总共:', total, '使用:', used, '剩余:', free) class Net(nn.Module): def __init__(self, gradient_ceckpoint=False): super(Net, self).__init__() self.gradient_ceckpoint = gradient_ceckpoint self.linear1 = nn.Linear(1024, 1024 * 10) self.block1 = nn.Sequential(*[nn.Linear(1024 * 10, 1024 * 10), nn.Tanh()]) self.block2 = nn.Sequential(*[nn.Linear(1024 * 10, 1024 * 10), nn.Tanh()]) self.block3 = nn.Sequential(*[nn.Linear(1024 * 10, 1024 * 10), nn.Tanh()]) def forward(self, inputs): if self.gradient_ceckpoint: inputs = self.linear1(inputs) inputs = checkpoint(self.block1, inputs) inputs = checkpoint(self.block2, inputs) output = self.block3(inputs) else: inputs = self.linear1(inputs) inputs = self.block1(inputs) inputs = self.block2(inputs) output = self.block3(inputs) return output if __name__ == '__main__': show_usage() print(cuda.memory_allocated()) model = Net(gradient_ceckpoint=False).cuda() model.train() for _ in range(2): inputs = torch.randn(size=[512, 1024]).cuda() print(cuda.memory_allocated()) output = model(inputs) print(cuda.memory_allocated()) loss = torch.mean(output) print(cuda.memory_allocated()) loss.backward() print(cuda.memory_allocated()) print('-' * 50) show_usage()
参考:https://spell.ml/blog/gradient-checkpointing-pytorch-YGypLBAAACEAefHs