AlexNet CIFAR10 Image Classification

AlexNet won the 2012 ImageNet competition. It was designed by Alex Krizhevsky, together with Ilya Sutskever, under the supervision of Geoffrey Hinton. AlexNet applied the basic principles of CNNs to a much deeper and wider network. Its main new techniques are as follows:

  1. Used ReLU as the CNN activation function and verified that it outperforms Sigmoid in deeper networks, successfully mitigating the vanishing-gradient problem Sigmoid suffers from as depth grows. Although ReLU had been proposed long before, it was AlexNet that popularized it.
  2. Applied Dropout in the final fully connected layers to randomly drop a subset of neurons and reduce overfitting. Dropout had been described in its own paper, but AlexNet put it into practice and demonstrated its effectiveness.
  3. Used overlapping max pooling. Earlier CNNs commonly used average pooling; AlexNet used max pooling throughout, avoiding the blurring effect of average pooling. AlexNet also proposed making the stride smaller than the pooling kernel, so that neighboring pooling outputs overlap, enriching the extracted features.
  4. Used CUDA to accelerate training, exploiting the massive parallel compute of GPUs for the large matrix operations involved in training neural networks.
  5. Data augmentation: randomly cropped 224*224 regions (and their horizontal mirrors) from the 256*256 source images to enlarge the dataset.
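Item 1 above can be illustrated with a back-of-the-envelope calculation: the sigmoid derivative peaks at 0.25, so a chain of sigmoid layers multiplies the backpropagated gradient by at most 0.25 per layer, while active ReLU units pass gradients through unchanged. A minimal pure-Python sketch (illustrative only, not part of the original article):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_grad(x):
    # derivative of sigmoid: s(x) * (1 - s(x)), maximized at x = 0 with value 0.25
    s = sigmoid(x)
    return s * (1.0 - s)

def relu_grad(x):
    # derivative of ReLU: 1 for positive inputs, 0 otherwise
    return 1.0 if x > 0 else 0.0

depth = 20
# even in the best case (all pre-activations at 0), the chained sigmoid
# gradient factor decays as 0.25 ** depth
sigmoid_chain = sigmoid_grad(0.0) ** depth
# for active ReLU units the chained factor stays at 1
relu_chain = relu_grad(1.0) ** depth

print(sigmoid_chain)  # 0.25 ** 20, about 9.1e-13
print(relu_chain)     # 1.0
```

After 20 layers the sigmoid factor is already below 1e-12, which is why deep sigmoid networks were so hard to train before ReLU.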

1. Network Structure

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )
)
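The spatial sizes in the dump above can be verified with the standard output-size formula floor((in + 2*padding - kernel) / stride) + 1. A quick sketch tracing a 224x224 input through the features stack (a verification aid, not part of the original code):

```python
def out_size(size, kernel, stride, padding=0):
    # standard conv / pooling output-size formula
    return (size + 2 * padding - kernel) // stride + 1

size = 224
size = out_size(size, kernel=11, stride=4, padding=2)  # Conv2d(3, 64)    -> 55
size = out_size(size, kernel=3, stride=2)              # MaxPool2d        -> 27
size = out_size(size, kernel=5, stride=1, padding=2)   # Conv2d(64, 192)  -> 27
size = out_size(size, kernel=3, stride=2)              # MaxPool2d        -> 13
size = out_size(size, kernel=3, stride=1, padding=1)   # Conv2d(192, 384) -> 13
size = out_size(size, kernel=3, stride=1, padding=1)   # Conv2d(384, 256) -> 13
size = out_size(size, kernel=3, stride=1, padding=1)   # Conv2d(256, 256) -> 13
size = out_size(size, kernel=3, stride=2)              # MaxPool2d        -> 6

# 256 channels * 6 * 6 spatial positions = 9216, matching the
# in_features of the first Linear layer in the classifier
print(size, 256 * size * size)  # 6 9216
```

Note how each MaxPool2d uses kernel 3 with stride 2, i.e. the overlapping pooling described above.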

2. Network Training

We use AlexNet weights pretrained on ImageNet, whose normalization statistics are mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225], so we normalize the new data with the same mean and standard deviation.
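Normalization maps each channel as (x - mean) / std; to display a normalized image later you invert it with x * std + mean. A tiny pure-Python sketch of both directions for a single RGB pixel (illustrative, not part of the original code):

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize(pixel):
    # per-channel (x - mean) / std, as torchvision's F.normalize applies
    return [(x - m) / s for x, m, s in zip(pixel, MEAN, STD)]

def denormalize(pixel):
    # inverse mapping, useful when plotting normalized images
    return [x * s + m for x, m, s in zip(pixel, MEAN, STD)]

pixel = [0.5, 0.5, 0.5]
restored = denormalize(normalize(pixel))
print(restored)  # recovers approximately [0.5, 0.5, 0.5]
```

A pixel equal to the per-channel mean maps to exactly zero, so after normalization each channel is roughly zero-centered with unit variance over the ImageNet distribution.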

import torch
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import glob
import matplotlib.pyplot as plt
import ssl
# allow torchvision downloads on systems with certificate verification issues
ssl._create_default_https_context = ssl._create_unverified_context

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def collate_function(batch_data):

    batch_images = []
    batch_labels = []
    for image, label in batch_data:
        # scale pixel values to [0, 1]
        image = F.to_tensor(image)
        # normalize with the ImageNet mean / std
        image = F.normalize(image,
                            mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
        # resize to AlexNet's expected 224x224 input
        image = F.resize(image, size=(224, 224))
        batch_images.append(image)
        batch_labels.append(label)

    batch_images = torch.stack(batch_images, dim=0).to(device)
    batch_labels = torch.tensor(batch_labels, device=device)

    return batch_images, batch_labels

def train():

    # load the ImageNet-pretrained model
    model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
    # replace the output layer to match the 10 CIFAR10 classes
    model.classifier[-1] = nn.Linear(in_features=4096, out_features=10)
    model = model.to(device)
    model.train()
    # loss function
    criterion = nn.CrossEntropyLoss()
    # optimizer
    optimizer = optim.AdamW(model.parameters(), lr=1e-5)
    # load the training set (set download=True on the first run)
    train_data = datasets.CIFAR10(root='data', train=True, download=False)
    # number of epochs
    epoch_num = 30
    # batch size
    batch_size = 128
    # training data loader
    dataloader = DataLoader(train_data,
                            shuffle=True,
                            batch_size=batch_size,
                            collate_fn=collate_function)
    # per-epoch loss history
    losses = []
    for epoch in range(epoch_num):

        progress = tqdm(range(len(dataloader)))
        total_loss = 0.0
        total_size = 0.0
        right_size = 0.0

        for batch_images, batch_labels in dataloader:
            outputs = model(batch_images)
            loss = criterion(outputs, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # predicted labels
            pred_labels = torch.argmax(outputs, dim=-1)
            total_size += len(batch_labels)
            right_size += torch.sum(pred_labels == batch_labels).item()

            total_loss += loss.item()
            desc = 'epoch %2d loss %.4f acc %.4f' % (epoch + 1, loss.item(), right_size/total_size)
            progress.set_description(desc)
            progress.update()

        desc = 'epoch %2d loss %.4f acc %.4f' % (epoch+1, total_loss, right_size/total_size)
        progress.set_description(desc)
        progress.close()

        # save a checkpoint for this epoch (the model/ directory must already exist)
        model_path = 'model/alexnet-%02d.bin' % (epoch + 1)
        torch.save(model.state_dict(), model_path)
        losses.append(total_loss)

    # plot the loss curve
    plt.plot(range(epoch_num), losses)
    plt.title('Loss Curve')
    plt.xticks(range(epoch_num)[::2], range(epoch_num)[::2])
    plt.xlim((0, epoch_num))
    plt.grid(True)
    plt.show()


if __name__ == '__main__':
    train()

Training output:

epoch  1 loss 345.0369 acc 0.6966: 100%|██████| 391/391 [01:01<00:00,  6.34it/s]
epoch  2 loss 187.5091 acc 0.8333: 100%|██████| 391/391 [01:02<00:00,  6.27it/s]
epoch  3 loss 154.3576 acc 0.8627: 100%|██████| 391/391 [01:03<00:00,  6.12it/s]
epoch  4 loss 131.4076 acc 0.8831: 100%|██████| 391/391 [01:03<00:00,  6.11it/s]
epoch  5 loss 115.3511 acc 0.8974: 100%|██████| 391/391 [01:03<00:00,  6.11it/s]
epoch  6 loss 102.1826 acc 0.9083: 100%|██████| 391/391 [01:04<00:00,  6.10it/s]
epoch  7 loss 91.1191 acc 0.9193: 100%|███████| 391/391 [01:04<00:00,  6.10it/s]
epoch  8 loss 81.5136 acc 0.9278: 100%|███████| 391/391 [01:04<00:00,  6.11it/s]
epoch  9 loss 72.9525 acc 0.9358: 100%|███████| 391/391 [01:04<00:00,  6.10it/s]
epoch 10 loss 65.9252 acc 0.9429: 100%|███████| 391/391 [01:03<00:00,  6.11it/s]
epoch 11 loss 58.5516 acc 0.9482: 100%|███████| 391/391 [01:03<00:00,  6.11it/s]
epoch 12 loss 53.6095 acc 0.9526: 100%|███████| 391/391 [01:04<00:00,  6.11it/s]
epoch 13 loss 47.0796 acc 0.9585: 100%|███████| 391/391 [01:03<00:00,  6.11it/s]
epoch 14 loss 42.3943 acc 0.9634: 100%|███████| 391/391 [01:04<00:00,  6.10it/s]
epoch 15 loss 39.8421 acc 0.9648: 100%|███████| 391/391 [01:03<00:00,  6.11it/s]
epoch 16 loss 33.9965 acc 0.9708: 100%|███████| 391/391 [01:03<00:00,  6.11it/s]
epoch 17 loss 30.9497 acc 0.9730: 100%|███████| 391/391 [01:04<00:00,  6.11it/s]
epoch 18 loss 27.8056 acc 0.9765: 100%|███████| 391/391 [01:04<00:00,  6.10it/s]
epoch 19 loss 25.4500 acc 0.9786: 100%|███████| 391/391 [01:04<00:00,  6.11it/s]
epoch 20 loss 23.0980 acc 0.9804: 100%|███████| 391/391 [01:04<00:00,  6.11it/s]
epoch 21 loss 21.5326 acc 0.9820: 100%|███████| 391/391 [01:04<00:00,  6.11it/s]
epoch 22 loss 20.3412 acc 0.9831: 100%|███████| 391/391 [01:04<00:00,  6.11it/s]
epoch 23 loss 17.2332 acc 0.9856: 100%|███████| 391/391 [01:04<00:00,  6.11it/s]
epoch 24 loss 15.8865 acc 0.9873: 100%|███████| 391/391 [01:04<00:00,  6.10it/s]
epoch 25 loss 15.2423 acc 0.9875: 100%|███████| 391/391 [01:04<00:00,  6.11it/s]
epoch 26 loss 14.0911 acc 0.9890: 100%|███████| 391/391 [01:03<00:00,  6.11it/s]
epoch 27 loss 12.7544 acc 0.9898: 100%|███████| 391/391 [01:04<00:00,  6.09it/s]
epoch 28 loss 11.8810 acc 0.9907: 100%|███████| 391/391 [01:04<00:00,  6.10it/s]
epoch 29 loss 10.7400 acc 0.9917: 100%|███████| 391/391 [01:04<00:00,  6.10it/s]
epoch 30 loss 10.9235 acc 0.9908: 100%|███████| 391/391 [01:04<00:00,  6.11it/s]

3. Network Evaluation

def evaluate(model_path):

    model = models.AlexNet(num_classes=10)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()

    with torch.no_grad():

        test_data = datasets.CIFAR10(root='data', train=False, download=False)
        dataloader = DataLoader(test_data,
                                shuffle=True,
                                batch_size=128,
                                collate_fn=collate_function)

        total, right = 0.0, 0.0
        for batch_images, batch_labels in dataloader:
            outputs = model(batch_images)
            y_pred = torch.argmax(outputs, dim=-1)

            total += len(batch_labels)
            right += torch.sum(y_pred == batch_labels).item()

        print(model_path)
        print('Test Acc %.4f %d/%d' % (right / total, right, total))


def evaluate_model():

    model_list = glob.glob('model/*.bin')
    for model in model_list:
        evaluate(model)
        print('-' * 30)


if __name__ == '__main__':
    evaluate_model()
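Since evaluate_model prints one accuracy per checkpoint, a natural follow-up is selecting the best epoch. A small sketch over a few of the accuracies reported in the evaluation output (the dict name is illustrative):

```python
# test accuracy per checkpoint epoch (values copied from the evaluation output)
test_acc = {
    15: 0.9165,
    24: 0.9197,
    29: 0.9203,
    30: 0.9175,
}

# pick the epoch whose checkpoint scored highest on the test set
best_epoch = max(test_acc, key=test_acc.get)
print(best_epoch, test_acc[best_epoch])  # 29 0.9203
```

Test accuracy plateaus around 0.92 and peaks at epoch 29 even though training accuracy keeps climbing to 0.99, a sign of mild overfitting in the later epochs.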

Evaluation output:

model/alexnet-01.bin
Test Acc 0.8324 8324/10000
------------------------------
model/alexnet-02.bin
Test Acc 0.8612 8612/10000
------------------------------
model/alexnet-03.bin
Test Acc 0.8765 8765/10000
------------------------------
model/alexnet-04.bin
Test Acc 0.8833 8833/10000
------------------------------
model/alexnet-05.bin
Test Acc 0.8931 8931/10000
------------------------------
model/alexnet-06.bin
Test Acc 0.8979 8979/10000
------------------------------
model/alexnet-07.bin
Test Acc 0.9034 9034/10000
------------------------------
model/alexnet-08.bin
Test Acc 0.9061 9061/10000
------------------------------
model/alexnet-09.bin
Test Acc 0.9022 9022/10000
------------------------------
model/alexnet-10.bin
Test Acc 0.9086 9086/10000
------------------------------
model/alexnet-11.bin
Test Acc 0.9110 9110/10000
------------------------------
model/alexnet-12.bin
Test Acc 0.9146 9146/10000
------------------------------
model/alexnet-13.bin
Test Acc 0.9146 9146/10000
------------------------------
model/alexnet-14.bin
Test Acc 0.9162 9162/10000
------------------------------
model/alexnet-15.bin
Test Acc 0.9165 9165/10000
------------------------------
model/alexnet-16.bin
Test Acc 0.9145 9145/10000
------------------------------
model/alexnet-17.bin
Test Acc 0.9150 9150/10000
------------------------------
model/alexnet-18.bin
Test Acc 0.9185 9185/10000
------------------------------
model/alexnet-19.bin
Test Acc 0.9180 9180/10000
------------------------------
model/alexnet-20.bin
Test Acc 0.9180 9180/10000
------------------------------
model/alexnet-21.bin
Test Acc 0.9191 9191/10000
------------------------------
model/alexnet-22.bin
Test Acc 0.9176 9176/10000
------------------------------
model/alexnet-23.bin
Test Acc 0.9190 9190/10000
------------------------------
model/alexnet-24.bin
Test Acc 0.9197 9197/10000
------------------------------
model/alexnet-25.bin
Test Acc 0.9168 9168/10000
------------------------------
model/alexnet-26.bin
Test Acc 0.9186 9186/10000
------------------------------
model/alexnet-27.bin
Test Acc 0.9185 9185/10000
------------------------------
model/alexnet-28.bin
Test Acc 0.9185 9185/10000
------------------------------
model/alexnet-29.bin
Test Acc 0.9203 9203/10000
------------------------------
model/alexnet-30.bin
Test Acc 0.9175 9175/10000
------------------------------