恒等映射层 nn.Identity 的作用

nn.Identity 在一些编码中会看到,它的实现代码非常简单,仅仅是将输入的内容原封不动的输出。请看它在 PyTorch 中的实现代码:

class Identity(Module):
    r"""A placeholder identity operator that is argument-insensitive.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 20])

    """
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(Identity, self).__init__()

    def forward(self, input: Tensor) -> Tensor:
        return input

其使用的时候,我们一般按照下面的方式使用:

import torch
import torch.nn as nn


def test01():

    a = torch.tensor([100, 200])
    b = nn.Identity()(a)

    print('a =', a, id(a))
    print('b =', b, id(b))


if __name__ == '__main__':
    test01()

程序输出结果:

a = tensor([1000,  200]) 140279002222128
b = tensor([1000,  200]) 140279002222128

从程序的执行结果,可以看到,nn.Identity 仅仅是输入张量,返回张量本身而已。这样的一个实现,在大多数人看来似乎没有什么用,我根据自己的使用经验总结了下两个场景下的作用:

  1. 提高代码可读性
  2. 增减网络,达到简化网络的作用

接下来,我们分别看下这两个场景下 nn.Identity 是如何发挥作用的。

1. 代码可读性

下面的代码,将输入做各种变换之后,并跨层进行残差连接:

class NetWork(nn.Module):

    def __init__(self):
        super(NetWork, self).__init__()

        self.linear1 = nn.Linear(100, 200)
        self.linear2 = nn.Linear(200, 400)
        self.relu = nn.ReLU(inplace=True)


    def forward(self, inputs):
        residual = inputs

        inputs = self.linear1(inputs)
        self.relu(inputs)
        inputs = self.linear2(inputs)
        self.relu(inputs)

        # 残差连接
        inputs += residual

        return inputs

残差连接部分是:inputs += residual,这个代码看起来没什么,如果没有注释的话,可能并不能知道这行代码的作用,可读性稍微差一些。此时,我们可以用一个恒等的 nn.Identity 来替换下这行代码:

class NetWork(nn.Module):

    def __init__(self):
        super(NetWork, self).__init__()

        self.linear1 = nn.Linear(100, 200)
        self.linear2 = nn.Linear(200, 400)
        self.relu = nn.ReLU(inplace=True)

        # 增加一个恒等层
        self.shortcut = nn.Identity()


    def forward(self, inputs):
        residual = inputs

        inputs = self.linear1(inputs)
        self.relu(inputs)
        inputs = self.linear2(inputs)
        self.relu(inputs)

        # 残差连接
        inputs += self.shortcut(residual)

        return inputs

我们将 inputs += residual 替换为 inputs += self.shortcut(residual),首先这两个是等价的,再次的话,后者可能更能直接清楚看到是在做残差连接。

2. 增减网络

增减网络也是我们使用较多的一个场景。比如:一个 bert 预训练模型,我们想去掉其中的某些层,使之不发生作用。如下示例代码:

class NetWork(nn.Module):

    def __init__(self):
        super(NetWork, self).__init__()

        self.linear1 = nn.Linear(100, 100)
        self.linear2 = nn.Linear(100, 100)
        self.relu = nn.ReLU(inplace=True)

        # 增加一个恒等层
        self.shortcut = nn.Identity()


    def forward(self, inputs):
        residual = inputs

        inputs = self.linear1(inputs)
        self.relu(inputs)
        inputs = self.linear2(inputs)
        self.relu(inputs)

        # 残差连接
        inputs += self.shortcut(residual)
        return inputs

def test02():

    # 加载预训练模型参数
    estimator = NetWork()
    # 去掉网络中的某些层
    estimator.linear2 = nn.Identity()

    outputs = estimator(torch.randn(5, 100))
    print(outputs.shape)

我们通过将 estimator.linear2 层替换为了不作任何操作的 nn.Identity(),相当于不改变网络结构的情况下,使得某些层失效。但是这里需要注意以下几点:

  1. nn.Identity 默认参数只有 1 个,如果替换的层输入的参数多于 1 个,则是无法直接替换。
  2. nn.Identity 不会改变数据的维度,如果被替换层的前一个层和后一个层输出和输入维度不同,则也无法直接替换。

…后面再补充

未经允许不得转载:一亩三分地 » 恒等映射层 nn.Identity 的作用
评论 (0)

8 + 2 =