模型剪枝(Model Pruning)

对模型进行剪枝,使得模型参数稀疏化可以降低模型的复杂度,也能够一定程度上加快模型的计算速度。我们知道决策树通过剪枝能够起到正则化,防止过拟合。在深度学习模型中,裁剪也能够起到相应的作用。

模型的裁剪本质上是将部分的模型参数置为0,从而实现模型参数稀疏,以达到加快模型计算,防止过拟合。在 PyTorch 中提供了随机裁剪、根据维度的范数裁剪、全局裁剪,还有自定义裁剪的方法。接下来,我们就了解下其裁剪的过程。

裁剪的流程:先进行训练,得到 SOTA 模型,然后进行裁剪。裁剪之后,模型的精度肯定会受到影响,要对模型进行评估。如果需要,还可以进行微调(提高精度),再进行裁剪。

注意:再次微调时,被裁剪过的参数不会被更新,仍然为 0.

import torch
import torch.nn as nn
from torch.nn.utils.prune import random_structured
from torch.nn.utils.prune import random_unstructured
from torch.nn.utils.prune import l1_unstructured
from torch.nn.utils.prune import ln_structured
from torch.nn.utils.prune import global_unstructured
from torch.nn.utils.prune import L1Unstructured
from torch.nn.utils.prune import RandomUnstructured
from torch.nn.utils.prune import RandomStructured
from torch.nn.utils.prune import LnStructured


# 用于裁剪的简单网络
class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.weight = nn.Parameter(torch.tensor([[1., 2., 3.], [4., 5., 6.]]))
        self.bias = nn.Parameter(torch.tensor([10., 20., 30]))

1. 随机裁剪

PyTorch 中对于模型参数的随机裁剪提供了 random_structured 和 random_unstructured 两种 AP,两者的区别如下:

  1. random_structured 裁剪时需要指定对参数的裁剪维度,即: dim 参数;
  2. random_unstructured 则是针对所有参数进行随机裁剪,不需要指定 dim 参数;

随机裁剪在工作时,会根据指定的维度或者所有参数,将 amout 量的参数置为 0,amout 如果是 float,表示裁剪的比重,如果是 int 类型则表示指定具体的裁剪参数数量。

示例代码:

def test01():

    model = Model()
    print(model.state_dict()['weight'])
    print(model.state_dict()['bias'])

    print('-' * 30)

    # Prune entire (currently unpruned) channels in a tensor at random.
    # 随机在某个 dim 上稀疏一个维度的参数
    m = random_structured(model, name='weight', dim=1, amount=2)
    print(m.state_dict()['weight_orig'])
    print(m.state_dict()['weight_mask'])

    print('-' * 30)

    # Prune (currently unpruned) units in a tensor at random.
    # 随机稀疏掉指定数量的参数为0
    model = Model()
    m = random_unstructured(model, name='weight', amount=1)
    print(m.state_dict()['weight_orig'])
    print(m.state_dict()['weight_mask'])

if __name__ == '__main__':
    test01()

程序执行结果:

tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([10., 20., 30.])
------------------------------
tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([[0., 0., 1.],
        [0., 0., 1.]])
------------------------------
OrderedDict([('weight', tensor([[1., 2., 3.],
        [4., 5., 6.]])), ('bias', tensor([10., 20., 30.]))])
tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([[1., 1., 1.],
        [1., 1., 0.]])

注意:裁剪完成之后,原来 name 名字将会消失,取而代之的是 name_mask 和 name_orig,这两个的计算结果就是 name 对应的参数,name_mask 为 0 的地方表示被设置为 0 的参数。

另外需要注意的是,函数调用完毕之后会直接修改原来的模型,返回值则可以查看 name_mask 和 name_orig 的值。

2. 范数裁剪

我们也可以根据范数对模型参数进行裁剪,PyTorch 提供了基于 L1 范数的 l1_unstructured 和指定范数的 ln_structured 裁剪。这里你可以发现,l1_unstructured 是非结构化的裁剪策略,它裁剪时不需要指定裁剪维度。而 ln_structured 是结构化的裁剪策略,需要指定 dim 和范数。

  • l1_unstructured:将张量中的所有元素按照绝对值大小进行排序,然后将排序后绝对值最小的一定比例的元素设置为零。这种方法不考虑张量的形状和结构,所以是无结构化的剪枝。
  • l1_structured:将张量分解为多个小部分,并将每个小不分中的元素按照绝对值大小进行排序,然后将排序后绝对值最小的一定比例的元素设置为零。这种方法考虑张量的形状和结构,所以它是结构化的剪枝。

示例代码:

def test02():

    model = Model()
    print(model.state_dict())
    # Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm.
    # L1范数最低的就是权值最低的参数置为0
    m = l1_unstructured(model, name='weight', amount=2)
    print(m.state_dict()['weight_orig'])
    print(m.state_dict()['weight_mask'])


def test03():
    model = Model()
    print(model.state_dict())
    # Prune entire (currently unpruned) channels in a tensor based on their n-norm
    # 可以通过 dim 指定维度,并且通过 n 指定对 dim 维度的范数计算,将最小的 dim 参数置为0
    m = ln_structured(model, name='weight', amount=1, n=1, dim=0)
    print(m.state_dict()['weight_orig'])
    print(m.state_dict()['weight_mask'])


if __name__ == '__main__':
    test02()
    print('-' * 30)
    test03()

程序执行结果:

OrderedDict([('weight', tensor([[1., 2., 3.],
        [4., 5., 6.]])), ('bias', tensor([10., 20., 30.]))])
tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([[0., 0., 1.],
        [1., 1., 1.]])
------------------------------
OrderedDict([('weight', tensor([[1., 2., 3.],
        [4., 5., 6.]])), ('bias', tensor([10., 20., 30.]))])
tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([[0., 0., 0.],
        [1., 1., 1.]])

3. 全局裁剪

全局裁剪指的是一次性指定要裁剪的所有参数,并从所有参数的视角,根据 pruning_method 策略裁剪 amount 数量的参数。pruning_method 可选的参数值就是前面介绍的几种方法(当然也可以是自定义的):

  1. from torch.nn.utils.prune import L1Unstructured
  2. from torch.nn.utils.prune import RandomUnstructured
  3. from torch.nn.utils.prune import RandomStructured
  4. from torch.nn.utils.prune import LnStructured

注意:这个参数并不是函数形式,是类的形式。

def test04():

    model = Model()
    print(model.state_dict())

    # 定义要裁剪的参数
    prune_param = ((model, 'bias'), (model, 'weight'))
    # pruning_method 传递剪枝类名
    # 从全局角度进行 amout 数量的参数置0
    global_unstructured(parameters=prune_param,
                        pruning_method=RandomUnstructured,
                        amount=2)

    print(model.state_dict())


if __name__ == '__main__':
    test04()

程序执行结果:

OrderedDict([('weight', tensor([[1., 2., 3.],
        [4., 5., 6.]])), ('bias', tensor([10., 20., 30.]))])
OrderedDict([('bias_orig', tensor([10., 20., 30.])), ('weight_orig', tensor([[1., 2., 3.],
        [4., 5., 6.]])), ('bias_mask', tensor([0., 1., 1.])), ('weight_mask', tensor([[1., 1., 1.],
        [0., 1., 1.]]))])

4. remove 函数

def test05():

    from torch.nn.utils.prune import remove
    import torch.nn.utils.prune as prune
    
    model = Model()
    print(model.state_dict(), prune.is_pruned(model))  # False
    m = l1_unstructured(model, name='weight', amount=2)
    print(model.state_dict(), prune.is_pruned(model))  # True
    # 将 name_orig name_mask 恢复为 name,去除剪枝痕迹
    remove(model, name='weight')
    print(model.state_dict(), prune.is_pruned(model))  # False
未经允许不得转载:一亩三分地 » 模型剪枝(Model Pruning)
评论 (0)

6 + 8 =