nn.Identity
是 PyTorch 中的一个层,它的作用是不做任何改变地传递输入数据。它在前向传播时会返回输入数据本身,而不对其进行任何处理或变换。
class Identity(Module): def __init__(self, *args: Any, **kwargs: Any) -> None: super(Identity, self).__init__() def forward(self, input: Tensor) -> Tensor: return input
这样的一个实现,在大多数人看来似乎没有什么用,我根据自己的使用经验总结了下两个场景下的作用:
- 提高代码可读性
- 增减网络,达到简化网络的作用
接下来,我们分别看下这两个场景下 nn.Identity 是如何发挥作用的。
1. 代码可读性
下面的代码,将输入做各种变换之后,并跨层进行残差连接:
class NetWork(nn.Module): def __init__(self): super(NetWork, self).__init__() self.linear1 = nn.Linear(100, 200) self.linear2 = nn.Linear(200, 400) self.relu = nn.ReLU(inplace=True) def forward(self, inputs): residual = inputs inputs = self.linear1(inputs) self.relu(inputs) inputs = self.linear2(inputs) self.relu(inputs) # 残差连接 inputs += residual return inputs
残差连接部分是:inputs += residual,这个代码看起来没什么,如果没有注释的话,可能并不能知道这行代码的作用,可读性稍微差一些。此时,我们可以用一个恒等的 nn.Identity 来替换下这行代码:
class NetWork(nn.Module): def __init__(self): super(NetWork, self).__init__() self.linear1 = nn.Linear(100, 200) self.linear2 = nn.Linear(200, 400) self.relu = nn.ReLU(inplace=True) # 增加一个恒等层 self.shortcut = nn.Identity() def forward(self, inputs): residual = inputs inputs = self.linear1(inputs) self.relu(inputs) inputs = self.linear2(inputs) self.relu(inputs) # 残差连接 inputs += self.shortcut(residual) return inputs
我们将 inputs += residual 替换为 inputs += self.shortcut(residual),首先这两个是等价的,再次的话,后者可能更能直接清楚看到是在做残差连接。
2. 增减网络
增减网络也是我们使用较多的一个场景。比如:一个 bert 预训练模型,我们想去掉其中的某些层,使之不发生作用。如下示例代码:
class NetWork(nn.Module): def __init__(self): super(NetWork, self).__init__() self.linear1 = nn.Linear(100, 100) self.linear2 = nn.Linear(100, 100) self.relu = nn.ReLU(inplace=True) # 增加一个恒等层 self.shortcut = nn.Identity() def forward(self, inputs): residual = inputs inputs = self.linear1(inputs) self.relu(inputs) inputs = self.linear2(inputs) self.relu(inputs) # 残差连接 inputs += self.shortcut(residual) return inputs def test02(): # 加载预训练模型参数 estimator = NetWork() # 去掉网络中的某些层 estimator.linear2 = nn.Identity() outputs = estimator(torch.randn(5, 100)) print(outputs.shape)
我们通过将 estimator.linear2 层替换为了不作任何操作的 nn.Identity(),相当于不改变网络结构的情况下,使得某些层失效。但是这里需要注意以下几点:
- nn.Identity 默认参数只有 1 个,如果替换的层输入的参数多于 1 个,则是无法直接替换。
- nn.Identity 不会改变数据的维度,如果被替换层的前一个层和后一个层输出和输入维度不同,则也无法直接替换。