AlexNet, designed by Alex Krizhevsky together with his advisor Geoffrey Hinton, won the 2012 ImageNet (ILSVRC) competition. It applied the basic principles of CNNs to a much deeper and wider network. The main new techniques it used are:
- It used ReLU as the CNN activation function and verified that ReLU outperforms Sigmoid in deeper networks, successfully mitigating the vanishing-gradient problem that Sigmoid suffers from when the network is deep. Although ReLU had been proposed long before, it was AlexNet that popularized it.
- AlexNet applies Dropout in the final fully connected layers, randomly ignoring a fraction of the neurons to avoid overfitting. Although Dropout has its own dedicated paper, AlexNet put it into practice and demonstrated its effectiveness empirically.
- It used overlapping max pooling. Average pooling was previously common in CNNs; AlexNet uses max pooling throughout, avoiding the blurring effect of average pooling. AlexNet also proposed making the stride smaller than the pooling kernel, so that adjacent pooling windows overlap, enriching the extracted features.
- It used CUDA to accelerate training, exploiting the GPU's massive parallelism to handle the large amount of matrix computation involved in training a neural network.
- Data augmentation: 224x224 regions (plus their horizontal mirrors) are randomly cropped from the 256x256 original images, increasing the amount of training data.
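The overlapping-pooling idea from the list above can be illustrated with a minimal sketch (the input size and channel count here are arbitrary, chosen only to make the arithmetic easy to follow):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 13, 13)

# Overlapping pooling as in AlexNet: stride (2) smaller than kernel (3),
# so adjacent 3x3 windows share a row/column of inputs
overlap = nn.MaxPool2d(kernel_size=3, stride=2)
# Non-overlapping pooling for comparison: stride equals kernel size
no_overlap = nn.MaxPool2d(kernel_size=2, stride=2)

# Output size = floor((13 - 3) / 2) + 1 = 6
print(overlap(x).shape)     # torch.Size([1, 1, 6, 6])
# Output size = floor((13 - 2) / 2) + 1 = 6
print(no_overlap(x).shape)  # torch.Size([1, 1, 6, 6])
```

Both variants produce the same output size here, but in the overlapping case each output draws on a 3x3 window that shares inputs with its neighbors.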
1. Network Architecture
```
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=5, bias=True)
  )
)
```
2. Network Training
Here we use AlexNet weights pretrained on ImageNet. The mean and std used for that pretraining are [0.485, 0.456, 0.406] and [0.229, 0.224, 0.225] respectively, so we normalize the new data with the same mean and standard deviation.
```python
import glob
import ssl

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

# Work around SSL certificate errors when downloading pretrained weights
ssl._create_default_https_context = ssl._create_unverified_context

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def collate_function(batch_data):
    batch_images = []
    batch_labels = []
    for image, label in batch_data:
        # Convert to a tensor with values in [0, 1]
        image = F.to_tensor(image)
        # Normalize with the ImageNet mean and std
        image = F.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # Resize to the input size expected by AlexNet
        image = F.resize(image, size=(224, 224))
        batch_images.append(image)
        batch_labels.append(label)
    batch_images = torch.stack(batch_images, dim=0).to(device)
    batch_labels = torch.tensor(batch_labels, device=device)
    return batch_images, batch_labels


def train():
    # Load the pretrained model
    model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
    # Replace the output layer to match the number of classes
    model.classifier[-1] = nn.Linear(in_features=4096, out_features=10)
    model = model.to(device)
    model.train()
    # Loss function
    criterion = nn.CrossEntropyLoss()
    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=1e-5)
    # Load the dataset
    train_data = datasets.CIFAR10(root='data', train=True, download=False)
    # Number of epochs
    epoch_num = 30
    # Batch size
    batch_size = 128
    # Training loader
    dataloader = DataLoader(train_data, shuffle=True, batch_size=batch_size,
                            collate_fn=collate_function)
    # Loss history
    losses = []
    for epoch in range(epoch_num):
        progress = tqdm(range(len(dataloader)))
        total_loss = 0.0
        total_size = 0.0
        right_size = 0.0
        for batch_images, batch_labels in dataloader:
            outputs = model(batch_images)
            loss = criterion(outputs, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Predicted labels
            pred_labels = torch.argmax(outputs, dim=-1)
            total_size += len(batch_labels)
            right_size += torch.sum(pred_labels == batch_labels).item()
            total_loss += loss.item()
            desc = 'epoch %2d loss %.4f acc %.4f' % (epoch + 1, loss.item(), right_size / total_size)
            progress.set_description(desc)
            progress.update()
        desc = 'epoch %2d loss %.4f acc %.4f' % (epoch + 1, total_loss, right_size / total_size)
        progress.set_description(desc)
        progress.close()
        # Save a checkpoint after each epoch
        model_path = 'model/alexnet-%02d.bin' % (epoch + 1)
        torch.save(model.state_dict(), model_path)
        losses.append(total_loss)
    # Plot the loss curve
    plt.plot(range(epoch_num), losses)
    plt.title('Loss Curve')
    plt.xticks(range(epoch_num)[::2], range(epoch_num)[::2])
    plt.xlim((0, epoch_num))
    plt.grid(True)
    plt.show()


if __name__ == '__main__':
    train()
```
Training output:
```
epoch  1 loss 345.0369 acc 0.6966: 100%|██████| 391/391 [01:01<00:00, 6.34it/s]
epoch  2 loss 187.5091 acc 0.8333: 100%|██████| 391/391 [01:02<00:00, 6.27it/s]
epoch  3 loss 154.3576 acc 0.8627: 100%|██████| 391/391 [01:03<00:00, 6.12it/s]
epoch  4 loss 131.4076 acc 0.8831: 100%|██████| 391/391 [01:03<00:00, 6.11it/s]
epoch  5 loss 115.3511 acc 0.8974: 100%|██████| 391/391 [01:03<00:00, 6.11it/s]
epoch  6 loss 102.1826 acc 0.9083: 100%|██████| 391/391 [01:04<00:00, 6.10it/s]
epoch  7 loss 91.1191 acc 0.9193: 100%|███████| 391/391 [01:04<00:00, 6.10it/s]
epoch  8 loss 81.5136 acc 0.9278: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch  9 loss 72.9525 acc 0.9358: 100%|███████| 391/391 [01:04<00:00, 6.10it/s]
epoch 10 loss 65.9252 acc 0.9429: 100%|███████| 391/391 [01:03<00:00, 6.11it/s]
epoch 11 loss 58.5516 acc 0.9482: 100%|███████| 391/391 [01:03<00:00, 6.11it/s]
epoch 12 loss 53.6095 acc 0.9526: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 13 loss 47.0796 acc 0.9585: 100%|███████| 391/391 [01:03<00:00, 6.11it/s]
epoch 14 loss 42.3943 acc 0.9634: 100%|███████| 391/391 [01:04<00:00, 6.10it/s]
epoch 15 loss 39.8421 acc 0.9648: 100%|███████| 391/391 [01:03<00:00, 6.11it/s]
epoch 16 loss 33.9965 acc 0.9708: 100%|███████| 391/391 [01:03<00:00, 6.11it/s]
epoch 17 loss 30.9497 acc 0.9730: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 18 loss 27.8056 acc 0.9765: 100%|███████| 391/391 [01:04<00:00, 6.10it/s]
epoch 19 loss 25.4500 acc 0.9786: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 20 loss 23.0980 acc 0.9804: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 21 loss 21.5326 acc 0.9820: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 22 loss 20.3412 acc 0.9831: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 23 loss 17.2332 acc 0.9856: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 24 loss 15.8865 acc 0.9873: 100%|███████| 391/391 [01:04<00:00, 6.10it/s]
epoch 25 loss 15.2423 acc 0.9875: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 26 loss 14.0911 acc 0.9890: 100%|███████| 391/391 [01:03<00:00, 6.11it/s]
epoch 27 loss 12.7544 acc 0.9898: 100%|███████| 391/391 [01:04<00:00, 6.09it/s]
epoch 28 loss 11.8810 acc 0.9907: 100%|███████| 391/391 [01:04<00:00, 6.10it/s]
epoch 29 loss 10.7400 acc 0.9917: 100%|███████| 391/391 [01:04<00:00, 6.10it/s]
epoch 30 loss 10.9235 acc 0.9908: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
```
3. Network Evaluation
```python
def evaluate(model_path):
    model = models.AlexNet(num_classes=10)
    model.load_state_dict(torch.load(model_path))
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        test_data = datasets.CIFAR10(root='data', train=False, download=False)
        dataloader = DataLoader(test_data, shuffle=True, batch_size=128,
                                collate_fn=collate_function)
        total, right = 0.0, 0.0
        for batch_images, batch_labels in dataloader:
            outputs = model(batch_images)
            y_pred = torch.argmax(outputs, dim=-1)
            total += len(batch_labels)
            right += torch.sum(y_pred == batch_labels).item()
        print(model_path)
        print('Test Acc %.4f %d/%d' % (right / total, right, total))


def evaluate_model():
    # Evaluate every checkpoint saved during training
    model_list = glob.glob('model/*.bin')
    for model in model_list:
        evaluate(model)
        print('-' * 30)


if __name__ == '__main__':
    evaluate_model()
```
Evaluation output:
```
model/alexnet-01.bin
Test Acc 0.8324 8324/10000
------------------------------
model/alexnet-02.bin
Test Acc 0.8612 8612/10000
------------------------------
model/alexnet-03.bin
Test Acc 0.8765 8765/10000
------------------------------
model/alexnet-04.bin
Test Acc 0.8833 8833/10000
------------------------------
model/alexnet-05.bin
Test Acc 0.8931 8931/10000
------------------------------
model/alexnet-06.bin
Test Acc 0.8979 8979/10000
------------------------------
model/alexnet-07.bin
Test Acc 0.9034 9034/10000
------------------------------
model/alexnet-08.bin
Test Acc 0.9061 9061/10000
------------------------------
model/alexnet-09.bin
Test Acc 0.9022 9022/10000
------------------------------
model/alexnet-10.bin
Test Acc 0.9086 9086/10000
------------------------------
model/alexnet-11.bin
Test Acc 0.9110 9110/10000
------------------------------
model/alexnet-12.bin
Test Acc 0.9146 9146/10000
------------------------------
model/alexnet-13.bin
Test Acc 0.9146 9146/10000
------------------------------
model/alexnet-14.bin
Test Acc 0.9162 9162/10000
------------------------------
model/alexnet-15.bin
Test Acc 0.9165 9165/10000
------------------------------
model/alexnet-16.bin
Test Acc 0.9145 9145/10000
------------------------------
model/alexnet-17.bin
Test Acc 0.9150 9150/10000
------------------------------
model/alexnet-18.bin
Test Acc 0.9185 9185/10000
------------------------------
model/alexnet-19.bin
Test Acc 0.9180 9180/10000
------------------------------
model/alexnet-20.bin
Test Acc 0.9180 9180/10000
------------------------------
model/alexnet-21.bin
Test Acc 0.9191 9191/10000
------------------------------
model/alexnet-22.bin
Test Acc 0.9176 9176/10000
------------------------------
model/alexnet-23.bin
Test Acc 0.9190 9190/10000
------------------------------
model/alexnet-24.bin
Test Acc 0.9197 9197/10000
------------------------------
model/alexnet-25.bin
Test Acc 0.9168 9168/10000
------------------------------
model/alexnet-26.bin
Test Acc 0.9186 9186/10000
------------------------------
model/alexnet-27.bin
Test Acc 0.9185 9185/10000
------------------------------
model/alexnet-28.bin
Test Acc 0.9185 9185/10000
------------------------------
model/alexnet-29.bin
Test Acc 0.9203 9203/10000
------------------------------
model/alexnet-30.bin
Test Acc 0.9175 9175/10000
------------------------------
```