对模型进行剪枝,使得模型参数稀疏化可以降低模型的复杂度,也能够一定程度上加快模型的计算速度。我们知道决策树通过剪枝能够起到正则化,防止过拟合。在深度学习模型中,裁剪也能够起到相应的作用。
模型的裁剪本质上是将部分的模型参数置为0,从而实现模型参数稀疏,以达到加快模型计算,防止过拟合。在 PyTorch 中提供了随机裁剪、根据维度的范数裁剪、全局裁剪,还有自定义裁剪的方法。接下来,我们就了解下其裁剪的过程。
裁剪的流程:先进行训练,得到 SOTA 模型,然后进行裁剪。裁剪之后,模型的精度肯定会受到影响,要对模型进行评估。如果需要,还可以进行微调(提高精度),再进行裁剪。
注意:再次微调时,被裁剪过的参数不会被更新,仍然为 0.
import torch
import torch.nn as nn
from torch.nn.utils.prune import random_structured
from torch.nn.utils.prune import random_unstructured
from torch.nn.utils.prune import l1_unstructured
from torch.nn.utils.prune import ln_structured
from torch.nn.utils.prune import global_unstructured
from torch.nn.utils.prune import L1Unstructured
from torch.nn.utils.prune import RandomUnstructured
from torch.nn.utils.prune import RandomStructured
from torch.nn.utils.prune import LnStructured
# 用于裁剪的简单网络
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.weight = nn.Parameter(torch.tensor([[1., 2., 3.], [4., 5., 6.]]))
self.bias = nn.Parameter(torch.tensor([10., 20., 30]))
1. 随机裁剪
PyTorch 中对于模型参数的随机裁剪提供了 random_structured 和 random_unstructured 两种 API,两者的区别如下:
- random_structured 裁剪时需要指定对参数的裁剪维度,即: dim 参数;
- random_unstructured 则是针对所有参数进行随机裁剪,不需要指定 dim 参数;
随机裁剪在工作时,会根据指定的维度或者所有参数,将 amout 数量的参数置为 0,amout 如果是 float,表示裁剪的比重,如果是 int 类型则表示指定具体的裁剪参数数量。
示例代码:
def test01():
model = Model()
print(model.state_dict()['weight'])
print(model.state_dict()['bias'])
print('-' * 30)
# Prune entire (currently unpruned) channels in a tensor at random.
# 随机在某个 dim 上稀疏一个维度的参数
m = random_structured(model, name='weight', dim=1, amount=2)
print(m.state_dict()['weight_orig'])
print(m.state_dict()['weight_mask'])
print('-' * 30)
# Prune (currently unpruned) units in a tensor at random.
# 随机稀疏掉指定数量的参数为0
model = Model()
m = random_unstructured(model, name='weight', amount=1)
print(m.state_dict()['weight_orig'])
print(m.state_dict()['weight_mask'])
if __name__ == '__main__':
test01()
程序执行结果:
tensor([[1., 2., 3.],
[4., 5., 6.]])
tensor([10., 20., 30.])
------------------------------
tensor([[1., 2., 3.],
[4., 5., 6.]])
tensor([[0., 0., 1.],
[0., 0., 1.]])
------------------------------
OrderedDict([('weight', tensor([[1., 2., 3.],
[4., 5., 6.]])), ('bias', tensor([10., 20., 30.]))])
tensor([[1., 2., 3.],
[4., 5., 6.]])
tensor([[1., 1., 1.],
[1., 1., 0.]])
注意:裁剪完成之后,原来 name 名字将会消失,取而代之的是 name_mask 和 name_orig,这两个的计算结果就是 name 对应的参数,name_mask 为 0 的地方表示被设置为 0 的参数。
另外需要注意的是,函数调用完毕之后会直接修改原来的模型,返回值则可以查看 name_mask 和 name_orig 的值。
2. 范数裁剪
我们也可以根据范数对模型参数进行裁剪,PyTorch 提供了基于 L1 范数的 l1_unstructured 和指定范数的 ln_structured 裁剪。这里你可以发现,l1_unstructured 是非结构化的裁剪策略,它裁剪时不需要指定裁剪维度。而 ln_structured 是结构化的裁剪策略,需要指定 dim 和范数。
- l1_unstructured:将张量中的所有元素按照绝对值大小进行排序,然后将排序后绝对值最小的一定比例的元素设置为零。这种方法不考虑张量的形状和结构,所以是无结构化的剪枝。
- l1_structured:将张量分解为多个小部分,并将每个小不分中的元素按照绝对值大小进行排序,然后将排序后绝对值最小的一定比例的元素设置为零。这种方法考虑张量的形状和结构,所以它是结构化的剪枝。
示例代码:
def test02():
model = Model()
print(model.state_dict())
# Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm.
# L1范数最低的就是权值最低的参数置为0
m = l1_unstructured(model, name='weight', amount=2)
print(m.state_dict()['weight_orig'])
print(m.state_dict()['weight_mask'])
def test03():
model = Model()
print(model.state_dict())
# Prune entire (currently unpruned) channels in a tensor based on their n-norm
# 可以通过 dim 指定维度,并且通过 n 指定对 dim 维度的范数计算,将最小的 dim 参数置为0
m = ln_structured(model, name='weight', amount=1, n=1, dim=0)
print(m.state_dict()['weight_orig'])
print(m.state_dict()['weight_mask'])
if __name__ == '__main__':
test02()
print('-' * 30)
test03()
程序执行结果:
OrderedDict([('weight', tensor([[1., 2., 3.],
[4., 5., 6.]])), ('bias', tensor([10., 20., 30.]))])
tensor([[1., 2., 3.],
[4., 5., 6.]])
tensor([[0., 0., 1.],
[1., 1., 1.]])
------------------------------
OrderedDict([('weight', tensor([[1., 2., 3.],
[4., 5., 6.]])), ('bias', tensor([10., 20., 30.]))])
tensor([[1., 2., 3.],
[4., 5., 6.]])
tensor([[0., 0., 0.],
[1., 1., 1.]])
3. 全局裁剪
全局裁剪指的是一次性指定要裁剪的所有参数,并从所有参数的视角,根据 pruning_method 策略裁剪 amount 数量的参数。pruning_method 可选的参数值就是前面介绍的几种方法(当然也可以是自定义的):
- from torch.nn.utils.prune import L1Unstructured
- from torch.nn.utils.prune import RandomUnstructured
- from torch.nn.utils.prune import RandomStructured
- from torch.nn.utils.prune import LnStructured
注意:这个参数并不是函数形式,是类的形式。
def test04():
model = Model()
print(model.state_dict())
# 定义要裁剪的参数
prune_param = ((model, 'bias'), (model, 'weight'))
# pruning_method 传递剪枝类名
# 从全局角度进行 amout 数量的参数置0
global_unstructured(parameters=prune_param,
pruning_method=RandomUnstructured,
amount=2)
print(model.state_dict())
if __name__ == '__main__':
test04()
程序执行结果:
OrderedDict([('weight', tensor([[1., 2., 3.],
[4., 5., 6.]])), ('bias', tensor([10., 20., 30.]))])
OrderedDict([('bias_orig', tensor([10., 20., 30.])), ('weight_orig', tensor([[1., 2., 3.],
[4., 5., 6.]])), ('bias_mask', tensor([0., 1., 1.])), ('weight_mask', tensor([[1., 1., 1.],
[0., 1., 1.]]))])
4. remove 函数
def test05():
from torch.nn.utils.prune import remove
import torch.nn.utils.prune as prune
model = Model()
print(model.state_dict(), prune.is_pruned(model)) # False
m = l1_unstructured(model, name='weight', amount=2)
print(model.state_dict(), prune.is_pruned(model)) # True
# 将 name_orig name_mask 恢复为 name,去除剪枝痕迹
remove(model, name='weight')
print(model.state_dict(), prune.is_pruned(model)) # False

冀公网安备13050302001966号