池化层(Pooling Layer)是卷积神经网络(CNN)中的一种层结构,主要用于对特征图进行降采样,以减小数据的空间尺寸,降低模型的计算量和参数数量,从而在一定程度上防止过拟合。池化层的主要作用包括:
- 降采样:通过对输入特征图进行下采样,减少数据的空间尺寸,从而降低模型的计算复杂度和参数数量,有助于防止过拟合。
- 特征选择:池化操作(如最大池化或平均池化)在每个池化窗口中提取主要特征,保留关键信息,忽略不重要的细节。
- 增强不变性:池化层使模型对输入数据的平移、旋转和缩放等变换更加鲁棒,提高模型的泛化能力。
常见的池化操作有以下几种:
- 最大池化(Max Pooling):将输入特征图划分为若干个矩形区域,对每个子区域输出最大值。
- 平均池化(Average Pooling):计算图像区域的平均值作为该区域池化后的值。
注意:池化层执行固定的数学操作,以对特征图进行降采样。这些操作不涉及权重或偏置等可训练参数,因此在模型训练过程中,池化层没有参数需要更新。
1. 单通道计算

- 最大池化:max(0, 1, 3, 4)、max(1, 2, 4, 5)、max(3, 4, 6, 7)、max(4, 5, 7, 8)
- 平均池化:mean(0, 1, 3, 4)、mean(1, 2, 4, 5)、mean(3, 4, 6, 7)、mean(4, 5, 7, 8)
2. Stride

- 最大池化:max(0, 1, 4, 5)、max(2, 3, 6, 7)、max(8, 9, 12, 13)、max(10, 11, 14, 15)
- 平均池化:mean(0, 1, 4, 5)、mean(2, 3, 6, 7)、mean(8, 9, 12, 13)、mean(10, 11, 14, 15)
3. Padding

- 最大池化:max(0, 0, 0, 0)、max(0, 0, 0, 1)、max(0, 0, 1, 2)、max(0, 0, 2, 0) … 以此类推
- 平均池化:mean(0, 0, 0, 0)、mean(0, 0, 0, 1)、mean(0, 0, 1, 2)、mean(0, 0, 2, 0) … 以此类推
经过池化之后特征图的维度计算可以用下面的公式:

假设输入图像:
- input width = 3,intput height = 3
- padding width = 1,padding height = 1
- Stride = 1
经过 filter width = 2,filter height = 2 的池化计算之后,特征图维度为:(3+2*1-2)/1=4
。
4. 多通道计算
在处理多通道输入数据时,池化层对每个输入通道分别池化,而不是像卷积层那样将各个通道的输入相加。这意味着池化层的输出和输入的通道数是相等。

5. 池化 API 使用
import torch import torch.nn as nn # 1. API 基本使用 def test01(): inputs = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]).float() inputs = inputs.unsqueeze(0).unsqueeze(0) # 1. 最大池化 # 输入形状: (N, C, H, W) polling = nn.MaxPool2d(kernel_size=2, stride=1, padding=0) output = polling(inputs) print(output) # 2. 平均池化 polling = nn.AvgPool2d(kernel_size=2, stride=1, padding=0) output = polling(inputs) print(output) # 2. stride 步长 def test02(): inputs = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]).float() inputs = inputs.unsqueeze(0).unsqueeze(0) # 1. 最大池化 polling = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) output = polling(inputs) print(output) # 2. 平均池化 polling = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) output = polling(inputs) print(output) # 3. padding 填充 def test03(): inputs = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]).float() inputs = inputs.unsqueeze(0).unsqueeze(0) # 1. 最大池化 polling = nn.MaxPool2d(kernel_size=2, stride=1, padding=1) output = polling(inputs) print(output) # 2. 平均池化 polling = nn.AvgPool2d(kernel_size=2, stride=1, padding=1) output = polling(inputs) print(output) # 4. 多通道池化 def test04(): inputs = torch.tensor([[[0, 1, 2], [3, 4, 5], [6, 7, 8]], [[10, 20, 30], [40, 50, 60], [70, 80, 90]], [[11, 22, 33], [44, 55, 66], [77, 88, 99]]]).float() inputs = inputs.unsqueeze(0) # 最大池化 polling = nn.MaxPool2d(kernel_size=2, stride=1, padding=0) output = polling(inputs) print(output) if __name__ == '__main__': test04()