nn.Identity 是 PyTorch 中的一个层,它的作用是不做任何改变地传递输入数据。它在前向传播时会返回输入数据本身,而不对其进行任何处理或变换。
class Identity(Module):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(Identity, self).__init__()
def forward(self, input: Tensor) -> Tensor:
return input
这样的一个实现,在大多数人看来似乎没有什么用,我根据自己的使用经验总结了下两个场景下的作用:
- 提高代码可读性
- 增减网络,达到简化网络的作用
接下来,我们分别看下这两个场景下 nn.Identity 是如何发挥作用的。
1. 代码可读性
下面的代码,将输入做各种变换之后,并跨层进行残差连接:
class NetWork(nn.Module):
def __init__(self):
super(NetWork, self).__init__()
self.linear1 = nn.Linear(100, 200)
self.linear2 = nn.Linear(200, 400)
self.relu = nn.ReLU(inplace=True)
def forward(self, inputs):
residual = inputs
inputs = self.linear1(inputs)
self.relu(inputs)
inputs = self.linear2(inputs)
self.relu(inputs)
# 残差连接
inputs += residual
return inputs
残差连接部分是:inputs += residual,这个代码看起来没什么,如果没有注释的话,可能并不能知道这行代码的作用,可读性稍微差一些。此时,我们可以用一个恒等的 nn.Identity 来替换下这行代码:
class NetWork(nn.Module):
def __init__(self):
super(NetWork, self).__init__()
self.linear1 = nn.Linear(100, 200)
self.linear2 = nn.Linear(200, 400)
self.relu = nn.ReLU(inplace=True)
# 增加一个恒等层
self.shortcut = nn.Identity()
def forward(self, inputs):
residual = inputs
inputs = self.linear1(inputs)
self.relu(inputs)
inputs = self.linear2(inputs)
self.relu(inputs)
# 残差连接
inputs += self.shortcut(residual)
return inputs
我们将 inputs += residual 替换为 inputs += self.shortcut(residual),首先这两个是等价的,再次的话,后者可能更能直接清楚看到是在做残差连接。
2. 增减网络
增减网络也是我们使用较多的一个场景。比如:一个 bert 预训练模型,我们想去掉其中的某些层,使之不发生作用。如下示例代码:
class NetWork(nn.Module):
def __init__(self):
super(NetWork, self).__init__()
self.linear1 = nn.Linear(100, 100)
self.linear2 = nn.Linear(100, 100)
self.relu = nn.ReLU(inplace=True)
# 增加一个恒等层
self.shortcut = nn.Identity()
def forward(self, inputs):
residual = inputs
inputs = self.linear1(inputs)
self.relu(inputs)
inputs = self.linear2(inputs)
self.relu(inputs)
# 残差连接
inputs += self.shortcut(residual)
return inputs
def test02():
# 加载预训练模型参数
estimator = NetWork()
# 去掉网络中的某些层
estimator.linear2 = nn.Identity()
outputs = estimator(torch.randn(5, 100))
print(outputs.shape)
我们通过将 estimator.linear2 层替换为了不作任何操作的 nn.Identity(),相当于不改变网络结构的情况下,使得某些层失效。但是这里需要注意以下几点:
- nn.Identity 默认参数只有 1 个,如果替换的层输入的参数多于 1 个,则是无法直接替换。
- nn.Identity 不会改变数据的维度,如果被替换层的前一个层和后一个层输出和输入维度不同,则也无法直接替换。

冀公网安备13050302001966号