对模型进行剪枝,使得模型参数稀疏化可以降低模型的复杂度,也能够一定程度上加快模型的计算速度。我们知道决策树通过剪枝能够起到正则化,防止过拟合。在深度学习模型中,裁剪也能够起到相应的作用。
模型的裁剪本质上是将部分的模型参数置为0,从而实现模型参数稀疏,以达到加快模型计算,防止过拟合。在 PyTorch 中提供了随机裁剪、根据维度的范数裁剪、全局裁剪,还有自定义裁剪的方法。接下来,我们就了解下其裁剪的过程。
裁剪的流程:先进行训练,得到 SOTA 模型,然后进行裁剪。裁剪之后,模型的精度肯定会受到影响,要对模型进行评估。如果需要,还可以进行微调(提高精度),再进行裁剪。
注意:再次微调时,被裁剪过的参数不会被更新,仍然为 0.
import torch import torch.nn as nn from torch.nn.utils.prune import random_structured from torch.nn.utils.prune import random_unstructured from torch.nn.utils.prune import l1_unstructured from torch.nn.utils.prune import ln_structured from torch.nn.utils.prune import global_unstructured from torch.nn.utils.prune import L1Unstructured from torch.nn.utils.prune import RandomUnstructured from torch.nn.utils.prune import RandomStructured from torch.nn.utils.prune import LnStructured # 用于裁剪的简单网络 class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.weight = nn.Parameter(torch.tensor([[1., 2., 3.], [4., 5., 6.]])) self.bias = nn.Parameter(torch.tensor([10., 20., 30]))
1. 随机裁剪
PyTorch 中对于模型参数的随机裁剪提供了 random_structured 和 random_unstructured 两种 AP,两者的区别如下:
- random_structured 裁剪时需要指定对参数的裁剪维度,即: dim 参数;
- random_unstructured 则是针对所有参数进行随机裁剪,不需要指定 dim 参数;
随机裁剪在工作时,会根据指定的维度或者所有参数,将 amout 量的参数置为 0,amout 如果是 float,表示裁剪的比重,如果是 int 类型则表示指定具体的裁剪参数数量。
示例代码:
def test01(): model = Model() print(model.state_dict()['weight']) print(model.state_dict()['bias']) print('-' * 30) # Prune entire (currently unpruned) channels in a tensor at random. # 随机在某个 dim 上稀疏一个维度的参数 m = random_structured(model, name='weight', dim=1, amount=2) print(m.state_dict()['weight_orig']) print(m.state_dict()['weight_mask']) print('-' * 30) # Prune (currently unpruned) units in a tensor at random. # 随机稀疏掉指定数量的参数为0 model = Model() m = random_unstructured(model, name='weight', amount=1) print(m.state_dict()['weight_orig']) print(m.state_dict()['weight_mask']) if __name__ == '__main__': test01()
程序执行结果:
tensor([[1., 2., 3.], [4., 5., 6.]]) tensor([10., 20., 30.]) ------------------------------ tensor([[1., 2., 3.], [4., 5., 6.]]) tensor([[0., 0., 1.], [0., 0., 1.]]) ------------------------------ OrderedDict([('weight', tensor([[1., 2., 3.], [4., 5., 6.]])), ('bias', tensor([10., 20., 30.]))]) tensor([[1., 2., 3.], [4., 5., 6.]]) tensor([[1., 1., 1.], [1., 1., 0.]])
注意:裁剪完成之后,原来 name 名字将会消失,取而代之的是 name_mask 和 name_orig,这两个的计算结果就是 name 对应的参数,name_mask 为 0 的地方表示被设置为 0 的参数。
另外需要注意的是,函数调用完毕之后会直接修改原来的模型,返回值则可以查看 name_mask 和 name_orig 的值。
2. 范数裁剪
我们也可以根据范数对模型参数进行裁剪,PyTorch 提供了基于 L1 范数的 l1_unstructured 和指定范数的 ln_structured 裁剪。这里你可以发现,l1_unstructured 是非结构化的裁剪策略,它裁剪时不需要指定裁剪维度。而 ln_structured 是结构化的裁剪策略,需要指定 dim 和范数。
- l1_unstructured:将张量中的所有元素按照绝对值大小进行排序,然后将排序后绝对值最小的一定比例的元素设置为零。这种方法不考虑张量的形状和结构,所以是无结构化的剪枝。
- l1_structured:将张量分解为多个小部分,并将每个小不分中的元素按照绝对值大小进行排序,然后将排序后绝对值最小的一定比例的元素设置为零。这种方法考虑张量的形状和结构,所以它是结构化的剪枝。
示例代码:
def test02(): model = Model() print(model.state_dict()) # Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm. # L1范数最低的就是权值最低的参数置为0 m = l1_unstructured(model, name='weight', amount=2) print(m.state_dict()['weight_orig']) print(m.state_dict()['weight_mask']) def test03(): model = Model() print(model.state_dict()) # Prune entire (currently unpruned) channels in a tensor based on their n-norm # 可以通过 dim 指定维度,并且通过 n 指定对 dim 维度的范数计算,将最小的 dim 参数置为0 m = ln_structured(model, name='weight', amount=1, n=1, dim=0) print(m.state_dict()['weight_orig']) print(m.state_dict()['weight_mask']) if __name__ == '__main__': test02() print('-' * 30) test03()
程序执行结果:
OrderedDict([('weight', tensor([[1., 2., 3.], [4., 5., 6.]])), ('bias', tensor([10., 20., 30.]))]) tensor([[1., 2., 3.], [4., 5., 6.]]) tensor([[0., 0., 1.], [1., 1., 1.]]) ------------------------------ OrderedDict([('weight', tensor([[1., 2., 3.], [4., 5., 6.]])), ('bias', tensor([10., 20., 30.]))]) tensor([[1., 2., 3.], [4., 5., 6.]]) tensor([[0., 0., 0.], [1., 1., 1.]])
3. 全局裁剪
全局裁剪指的是一次性指定要裁剪的所有参数,并从所有参数的视角,根据 pruning_method 策略裁剪 amount 数量的参数。pruning_method 可选的参数值就是前面介绍的几种方法(当然也可以是自定义的):
- from torch.nn.utils.prune import L1Unstructured
- from torch.nn.utils.prune import RandomUnstructured
- from torch.nn.utils.prune import RandomStructured
- from torch.nn.utils.prune import LnStructured
注意:这个参数并不是函数形式,是类的形式。
def test04(): model = Model() print(model.state_dict()) # 定义要裁剪的参数 prune_param = ((model, 'bias'), (model, 'weight')) # pruning_method 传递剪枝类名 # 从全局角度进行 amout 数量的参数置0 global_unstructured(parameters=prune_param, pruning_method=RandomUnstructured, amount=2) print(model.state_dict()) if __name__ == '__main__': test04()
程序执行结果:
OrderedDict([('weight', tensor([[1., 2., 3.], [4., 5., 6.]])), ('bias', tensor([10., 20., 30.]))]) OrderedDict([('bias_orig', tensor([10., 20., 30.])), ('weight_orig', tensor([[1., 2., 3.], [4., 5., 6.]])), ('bias_mask', tensor([0., 1., 1.])), ('weight_mask', tensor([[1., 1., 1.], [0., 1., 1.]]))])
4. remove 函数
def test05(): from torch.nn.utils.prune import remove import torch.nn.utils.prune as prune model = Model() print(model.state_dict(), prune.is_pruned(model)) # False m = l1_unstructured(model, name='weight', amount=2) print(model.state_dict(), prune.is_pruned(model)) # True # 将 name_orig name_mask 恢复为 name,去除剪枝痕迹 remove(model, name='weight') print(model.state_dict(), prune.is_pruned(model)) # False