池化层 (Pooling) 降低维度, 缩减模型大小,提高计算速度. 另外一个作用可以缓解卷积层对位置的敏感性.
池化层主要有两种:
- 最大池化
- 平均池化
1. 池化层计算
最大池化:
- max(0, 1, 3, 4)
- max(1, 2, 4, 5)
- max(3, 4, 6, 7)
- max(4, 5, 7, 8)
平均池化:
- mean(0, 1, 3, 4)
- mean(1, 2, 4, 5)
- mean(3, 4, 6, 7)
- mean(4, 5, 7, 8)
2. Stride
最大池化:
- max(0, 1, 4, 5)
- max(2, 3, 6, 7)
- max(8, 9, 12, 13)
- max(10, 11, 14, 15)
平均池化:
- mean(0, 1, 4, 5)
- mean(2, 3, 6, 7)
- mean(8, 9, 12, 13)
- mean(10, 11, 14, 15)
3. Padding
最大池化:
- max(0, 0, 0, 0)
- max(0, 0, 0, 1)
- max(0, 0, 1, 2)
- max(0, 0, 2, 0)
- … 以此类推
平均池化:
- mean(0, 0, 0, 0)
- mean(0, 0, 0, 1)
- mean(0, 0, 1, 2)
- mean(0, 0, 2, 0)
- … 以此类推
4. 多通道池化计算
在处理多通道输入数据时,池化层对每个输入通道分别池化,而不是像卷积层那样将各个通道的输入相加。这意味着池化层的输出和输入的通道数是相等。
5. PyTorch 池化 API 使用
import torch import torch.nn as nn # 1. API 基本使用 def test01(): inputs = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]).float() inputs = inputs.unsqueeze(0).unsqueeze(0) # 1. 最大池化 # 输入形状: (N, C, H, W) polling = nn.MaxPool2d(kernel_size=2, stride=1, padding=0) output = polling(inputs) print(output) # 2. 平均池化 polling = nn.AvgPool2d(kernel_size=2, stride=1, padding=0) output = polling(inputs) print(output) # 2. stride 步长 def test02(): inputs = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]).float() inputs = inputs.unsqueeze(0).unsqueeze(0) # 1. 最大池化 polling = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) output = polling(inputs) print(output) # 2. 平均池化 polling = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) output = polling(inputs) print(output) # 3. padding 填充 def test03(): inputs = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]).float() inputs = inputs.unsqueeze(0).unsqueeze(0) # 1. 最大池化 polling = nn.MaxPool2d(kernel_size=2, stride=1, padding=1) output = polling(inputs) print(output) # 2. 平均池化 polling = nn.AvgPool2d(kernel_size=2, stride=1, padding=1) output = polling(inputs) print(output) # 4. 多通道池化 def test04(): inputs = torch.tensor([[[0, 1, 2], [3, 4, 5], [6, 7, 8]], [[10, 20, 30], [40, 50, 60], [70, 80, 90]], [[11, 22, 33], [44, 55, 66], [77, 88, 99]]]).float() inputs = inputs.unsqueeze(0) # 最大池化 polling = nn.MaxPool2d(kernel_size=2, stride=1, padding=0) output = polling(inputs) print(output) if __name__ == '__main__': test04()