AlexNet was designed by Alex Krizhevsky, a student of Geoffrey Hinton, and won the 2012 ImageNet competition. It applied the basic principles of CNNs to a much deeper and wider network than before. Its main technical contributions are as follows:
- ReLU is used as the CNN's activation function, and AlexNet verified that it outperforms Sigmoid in deeper networks, successfully mitigating the vanishing-gradient problem that Sigmoid suffers from as networks get deep. Although ReLU had been proposed long before, it was AlexNet that popularized it.
- Dropout is applied in the final fully connected layers, randomly dropping a subset of neurons to prevent overfitting. Dropout had already been described in a separate paper, but AlexNet put it into practice and confirmed its effectiveness empirically.
- Overlapping max pooling is used throughout. Average pooling had been common in CNNs before; AlexNet uses max pooling everywhere, avoiding the blurring effect of average pooling. It also makes the stride smaller than the pooling kernel, so the outputs of neighboring pooling windows overlap, enriching the extracted features (see the sketch after this list).
- CUDA is used to accelerate training of the deep convolutional network, exploiting the GPU's massive parallelism to handle the large amount of matrix computation involved in training.
- Data augmentation: 224*224 regions (and their horizontally flipped mirrors) are randomly cropped from the 256*256 source images, increasing the amount of training data (also sketched below).
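To make the pooling and augmentation points concrete, here is a minimal PyTorch sketch (the feature-map shape is illustrative, not taken from the text):

import torch
import torch.nn as nn
import torchvision.transforms as T

# Overlapping max pooling: the stride (2) is smaller than the kernel (3),
# so neighboring 3x3 windows share a row/column, as AlexNet proposed.
x = torch.randn(1, 64, 55, 55)                       # illustrative feature map
overlap_pool = nn.MaxPool2d(kernel_size=3, stride=2)
print(overlap_pool(x).shape)                         # torch.Size([1, 64, 27, 27])

# AlexNet-style augmentation: random 224x224 crops from 256x256 images,
# plus random horizontal flips.
augment = T.Compose([
    T.RandomCrop(224),
    T.RandomHorizontalFlip(),
])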
1. Network Structure
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
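This printout can be reproduced directly from torchvision (a minimal sketch, assuming torchvision >= 0.13 for the weights enum). The final Linear layer has 1000 outputs because the pretrained model targets the 1000 ImageNet classes:

import torchvision.models as models

# Load the ImageNet-pretrained AlexNet and print its module hierarchy
model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
print(model)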
2. Network Training
Here we use AlexNet weights pretrained on ImageNet, which were trained with mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225], so we normalize our new data with the same mean and standard deviation.
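As a side note, recent torchvision versions bundle this preprocessing with the weights object, so the statistics need not be hard-coded; a small sketch:

import torchvision.models as models

weights = models.AlexNet_Weights.IMAGENET1K_V1
# Printing the bundled preset shows mean=[0.485, 0.456, 0.406] and
# std=[0.229, 0.224, 0.225], matching the values hard-coded below.
print(weights.transforms())

The full training script below sticks with the explicit functional transforms.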
import glob
import os
import ssl

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

# Work around SSL certificate errors when downloading the pretrained weights
ssl._create_default_https_context = ssl._create_unverified_context
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def collate_function(batch_data):
    batch_images = []
    batch_labels = []
    for image, label in batch_data:
        # Convert the PIL image to a tensor scaled to [0, 1]
        image = F.to_tensor(image)
        # Standardize with the ImageNet mean and std used in pretraining
        image = F.normalize(image,
                            mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
        # Resize to the 224x224 input size expected by AlexNet
        image = F.resize(image, size=(224, 224))
        batch_images.append(image)
        batch_labels.append(label)
    batch_images = torch.stack(batch_images, dim=0).to(device)
    batch_labels = torch.tensor(batch_labels, device=device)
    return batch_images, batch_labels
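As a quick sanity check (a sketch, assuming the CIFAR-10 training set loaded exactly as in the training code below), the collate function turns a list of (PIL image, label) pairs into batched device tensors:

train_data = datasets.CIFAR10(root='data', train=True, download=True)
images, labels = collate_function([train_data[0], train_data[1]])
print(images.shape)  # torch.Size([2, 3, 224, 224])
print(labels.shape)  # torch.Size([2])

The training routine below builds its DataLoader with this collate_fn.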
def train():
    # Load the pretrained model
    model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
    # Replace the output layer to match the number of target classes
    model.classifier[-1] = nn.Linear(in_features=4096, out_features=10)
    model = model.to(device)
    model.train()
    # Loss function
    criterion = nn.CrossEntropyLoss()
    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=1e-5)
    # Load the dataset (downloaded automatically if not present)
    train_data = datasets.CIFAR10(root='data', train=True, download=True)
    # Number of epochs
    epoch_num = 30
    # Batch size
    batch_size = 128
    # Training loader
    dataloader = DataLoader(train_data,
                            shuffle=True,
                            batch_size=batch_size,
                            collate_fn=collate_function)
    # Make sure the checkpoint directory exists
    os.makedirs('model', exist_ok=True)
    # Loss history
    losses = []
    for epoch in range(epoch_num):
        progress = tqdm(range(len(dataloader)))
        total_loss = 0.0
        total_size = 0.0
        right_size = 0.0
        for batch_images, batch_labels in dataloader:
            outputs = model(batch_images)
            loss = criterion(outputs, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Predicted labels
            pred_labels = torch.argmax(outputs, dim=-1)
            total_size += len(batch_labels)
            right_size += torch.sum(pred_labels == batch_labels).item()
            total_loss += loss.item()
            desc = 'epoch %2d loss %.4f acc %.4f' % (epoch + 1, loss.item(), right_size / total_size)
            progress.set_description(desc)
            progress.update()
        desc = 'epoch %2d loss %.4f acc %.4f' % (epoch + 1, total_loss, right_size / total_size)
        progress.set_description(desc)
        progress.close()
        # Save a checkpoint after each epoch
        model_path = 'model/alexnet-%02d.bin' % (epoch + 1)
        torch.save(model.state_dict(), model_path)
        losses.append(total_loss)
    # Plot the loss curve
    plt.plot(range(epoch_num), losses)
    plt.title('Loss Curve')
    plt.xticks(range(epoch_num)[::2], range(epoch_num)[::2])
    plt.xlim((0, epoch_num))
    plt.grid(True)
    plt.show()


if __name__ == '__main__':
    train()
Training output:

epoch  1 loss 345.0369 acc 0.6966: 100%|██████| 391/391 [01:01<00:00, 6.34it/s]
epoch  2 loss 187.5091 acc 0.8333: 100%|██████| 391/391 [01:02<00:00, 6.27it/s]
epoch  3 loss 154.3576 acc 0.8627: 100%|██████| 391/391 [01:03<00:00, 6.12it/s]
epoch  4 loss 131.4076 acc 0.8831: 100%|██████| 391/391 [01:03<00:00, 6.11it/s]
epoch  5 loss 115.3511 acc 0.8974: 100%|██████| 391/391 [01:03<00:00, 6.11it/s]
epoch  6 loss 102.1826 acc 0.9083: 100%|██████| 391/391 [01:04<00:00, 6.10it/s]
epoch  7 loss 91.1191 acc 0.9193: 100%|███████| 391/391 [01:04<00:00, 6.10it/s]
epoch  8 loss 81.5136 acc 0.9278: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch  9 loss 72.9525 acc 0.9358: 100%|███████| 391/391 [01:04<00:00, 6.10it/s]
epoch 10 loss 65.9252 acc 0.9429: 100%|███████| 391/391 [01:03<00:00, 6.11it/s]
epoch 11 loss 58.5516 acc 0.9482: 100%|███████| 391/391 [01:03<00:00, 6.11it/s]
epoch 12 loss 53.6095 acc 0.9526: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 13 loss 47.0796 acc 0.9585: 100%|███████| 391/391 [01:03<00:00, 6.11it/s]
epoch 14 loss 42.3943 acc 0.9634: 100%|███████| 391/391 [01:04<00:00, 6.10it/s]
epoch 15 loss 39.8421 acc 0.9648: 100%|███████| 391/391 [01:03<00:00, 6.11it/s]
epoch 16 loss 33.9965 acc 0.9708: 100%|███████| 391/391 [01:03<00:00, 6.11it/s]
epoch 17 loss 30.9497 acc 0.9730: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 18 loss 27.8056 acc 0.9765: 100%|███████| 391/391 [01:04<00:00, 6.10it/s]
epoch 19 loss 25.4500 acc 0.9786: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 20 loss 23.0980 acc 0.9804: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 21 loss 21.5326 acc 0.9820: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 22 loss 20.3412 acc 0.9831: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 23 loss 17.2332 acc 0.9856: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 24 loss 15.8865 acc 0.9873: 100%|███████| 391/391 [01:04<00:00, 6.10it/s]
epoch 25 loss 15.2423 acc 0.9875: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
epoch 26 loss 14.0911 acc 0.9890: 100%|███████| 391/391 [01:03<00:00, 6.11it/s]
epoch 27 loss 12.7544 acc 0.9898: 100%|███████| 391/391 [01:04<00:00, 6.09it/s]
epoch 28 loss 11.8810 acc 0.9907: 100%|███████| 391/391 [01:04<00:00, 6.10it/s]
epoch 29 loss 10.7400 acc 0.9917: 100%|███████| 391/391 [01:04<00:00, 6.10it/s]
epoch 30 loss 10.9235 acc 0.9908: 100%|███████| 391/391 [01:04<00:00, 6.11it/s]
3. Network Evaluation
def evaluate(model_path):
    # Build an AlexNet with a 10-class head and load the checkpoint
    model = models.AlexNet(num_classes=10)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        test_data = datasets.CIFAR10(root='data', train=False, download=True)
        dataloader = DataLoader(test_data,
                                shuffle=False,
                                batch_size=128,
                                collate_fn=collate_function)
        total, right = 0.0, 0.0
        for batch_images, batch_labels in dataloader:
            outputs = model(batch_images)
            y_pred = torch.argmax(outputs, dim=-1)
            total += len(batch_labels)
            right += torch.sum(y_pred == batch_labels).item()
        print(model_path)
        print('Test Acc %.4f %d/%d' % (right / total, right, total))


def evaluate_model():
    # Evaluate each saved checkpoint in order
    model_list = sorted(glob.glob('model/*.bin'))
    for model in model_list:
        evaluate(model)
        print('-' * 30)


if __name__ == '__main__':
    evaluate_model()
Evaluation output:
model/alexnet-01.bin Test Acc 0.8324 8324/10000
model/alexnet-02.bin Test Acc 0.8612 8612/10000
model/alexnet-03.bin Test Acc 0.8765 8765/10000
model/alexnet-04.bin Test Acc 0.8833 8833/10000
model/alexnet-05.bin Test Acc 0.8931 8931/10000
model/alexnet-06.bin Test Acc 0.8979 8979/10000
model/alexnet-07.bin Test Acc 0.9034 9034/10000
model/alexnet-08.bin Test Acc 0.9061 9061/10000
model/alexnet-09.bin Test Acc 0.9022 9022/10000
model/alexnet-10.bin Test Acc 0.9086 9086/10000
model/alexnet-11.bin Test Acc 0.9110 9110/10000
model/alexnet-12.bin Test Acc 0.9146 9146/10000
model/alexnet-13.bin Test Acc 0.9146 9146/10000
model/alexnet-14.bin Test Acc 0.9162 9162/10000
model/alexnet-15.bin Test Acc 0.9165 9165/10000
model/alexnet-16.bin Test Acc 0.9145 9145/10000
model/alexnet-17.bin Test Acc 0.9150 9150/10000
model/alexnet-18.bin Test Acc 0.9185 9185/10000
model/alexnet-19.bin Test Acc 0.9180 9180/10000
model/alexnet-20.bin Test Acc 0.9180 9180/10000
model/alexnet-21.bin Test Acc 0.9191 9191/10000
model/alexnet-22.bin Test Acc 0.9176 9176/10000
model/alexnet-23.bin Test Acc 0.9190 9190/10000
model/alexnet-24.bin Test Acc 0.9197 9197/10000
model/alexnet-25.bin Test Acc 0.9168 9168/10000
model/alexnet-26.bin Test Acc 0.9186 9186/10000
model/alexnet-27.bin Test Acc 0.9185 9185/10000
model/alexnet-28.bin Test Acc 0.9185 9185/10000
model/alexnet-29.bin Test Acc 0.9203 9203/10000
model/alexnet-30.bin Test Acc 0.9175 9175/10000
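The best test accuracy above comes from the epoch-29 checkpoint (0.9203). A minimal inference sketch using it follows; the class-name list and the image path 'test.png' are illustrative assumptions, and the preprocessing mirrors collate_function:

import torch
import torchvision.models as models
import torchvision.transforms.functional as F
from PIL import Image

# CIFAR-10 class names in label-index order
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = models.AlexNet(num_classes=10)
model.load_state_dict(torch.load('model/alexnet-29.bin', map_location=device))
model = model.to(device)
model.eval()

# 'test.png' is a placeholder path for an image to classify
image = Image.open('test.png').convert('RGB')
x = F.to_tensor(image)
x = F.normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
x = F.resize(x, size=(224, 224)).unsqueeze(0).to(device)
with torch.no_grad():
    pred = torch.argmax(model(x), dim=-1).item()
print(CIFAR10_CLASSES[pred])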
