nn.Identity 在一些编码中会看到,它的实现代码非常简单,仅仅是将输入的内容原封不动的输出。请看它在 PyTorch 中的实现代码:
class Identity(Module): r"""A placeholder identity operator that is argument-insensitive. Args: args: any argument (unused) kwargs: any keyword argument (unused) Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. Examples:: >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False) >>> input = torch.randn(128, 20) >>> output = m(input) >>> print(output.size()) torch.Size([128, 20]) """ def __init__(self, *args: Any, **kwargs: Any) -> None: super(Identity, self).__init__() def forward(self, input: Tensor) -> Tensor: return input
其使用的时候,我们一般按照下面的方式使用:
import torch import torch.nn as nn def test01(): a = torch.tensor([100, 200]) b = nn.Identity()(a) print('a =', a, id(a)) print('b =', b, id(b)) if __name__ == '__main__': test01()
程序输出结果:
a = tensor([1000, 200]) 140279002222128 b = tensor([1000, 200]) 140279002222128
从程序的执行结果,可以看到,nn.Identity 仅仅是输入张量,返回张量本身而已。这样的一个实现,在大多数人看来似乎没有什么用,我根据自己的使用经验总结了下两个场景下的作用:
- 提高代码可读性
- 增减网络,达到简化网络的作用
接下来,我们分别看下这两个场景下 nn.Identity 是如何发挥作用的。
1. 代码可读性
下面的代码,将输入做各种变换之后,并跨层进行残差连接:
class NetWork(nn.Module): def __init__(self): super(NetWork, self).__init__() self.linear1 = nn.Linear(100, 200) self.linear2 = nn.Linear(200, 400) self.relu = nn.ReLU(inplace=True) def forward(self, inputs): residual = inputs inputs = self.linear1(inputs) self.relu(inputs) inputs = self.linear2(inputs) self.relu(inputs) # 残差连接 inputs += residual return inputs
残差连接部分是:inputs += residual,这个代码看起来没什么,如果没有注释的话,可能并不能知道这行代码的作用,可读性稍微差一些。此时,我们可以用一个恒等的 nn.Identity 来替换下这行代码:
class NetWork(nn.Module): def __init__(self): super(NetWork, self).__init__() self.linear1 = nn.Linear(100, 200) self.linear2 = nn.Linear(200, 400) self.relu = nn.ReLU(inplace=True) # 增加一个恒等层 self.shortcut = nn.Identity() def forward(self, inputs): residual = inputs inputs = self.linear1(inputs) self.relu(inputs) inputs = self.linear2(inputs) self.relu(inputs) # 残差连接 inputs += self.shortcut(residual) return inputs
我们将 inputs += residual 替换为 inputs += self.shortcut(residual),首先这两个是等价的,再次的话,后者可能更能直接清楚看到是在做残差连接。
2. 增减网络
增减网络也是我们使用较多的一个场景。比如:一个 bert 预训练模型,我们想去掉其中的某些层,使之不发生作用。如下示例代码:
class NetWork(nn.Module): def __init__(self): super(NetWork, self).__init__() self.linear1 = nn.Linear(100, 100) self.linear2 = nn.Linear(100, 100) self.relu = nn.ReLU(inplace=True) # 增加一个恒等层 self.shortcut = nn.Identity() def forward(self, inputs): residual = inputs inputs = self.linear1(inputs) self.relu(inputs) inputs = self.linear2(inputs) self.relu(inputs) # 残差连接 inputs += self.shortcut(residual) return inputs def test02(): # 加载预训练模型参数 estimator = NetWork() # 去掉网络中的某些层 estimator.linear2 = nn.Identity() outputs = estimator(torch.randn(5, 100)) print(outputs.shape)
我们通过将 estimator.linear2 层替换为了不作任何操作的 nn.Identity(),相当于不改变网络结构的情况下,使得某些层失效。但是这里需要注意以下几点:
- nn.Identity 默认参数只有 1 个,如果替换的层输入的参数多于 1 个,则是无法直接替换。
- nn.Identity 不会改变数据的维度,如果被替换层的前一个层和后一个层输出和输入维度不同,则也无法直接替换。
…后面再补充