《手写数字识别器》(九)位置敏感

经过第八章节的特征优化,我们的手写数字识别准确率有着明显的提升。但是仍然存在一个显而易见的问题,即:我们绘制的数字是对位置很敏感的。接下来,我们将会去探讨解决这一问题的方法。

1. 图像数据增强

数据增强是一种通过对训练图像进行一系列随机改变,从而扩大训练数据集规模的技术。这种技术的核心目标是降低模型对某些属性的依赖性,以提高模型的泛化能力。

具体来说,我们可以采用不同方式,让感兴趣的物体出现在图像的不同位置上,这样可以减轻模型对物体出现位置的依赖性。同时,也可以调整图像的亮度、色彩等因素,以降低模型对色彩的敏感度。以下是几种常见的针对图像的位置敏感性的数据增强方法:

  1. 裁剪:从原始图像中随机选择一个矩形区域,并将这个区域作为新的图像输入。这样可以改变目标物体在图像中的位置,让模型学会忽略位置信息。
  2. 翻转:将图像沿着水平或垂直轴翻转,可以使目标物体出现在不同的位置上。这种方法也可以增加训练数据的数量。
  3. 缩放和旋转:通过对图像进行缩放或旋转,可以改变目标物体的大小和角度,使得模型更加鲁棒。
  4. 平移:在图像平面上移动图像,使目标物体在图像中的位置发生变化。

图像增强方法最直接的缺点,会使得训练集膨胀,需要更多的计算资源和时间。同时,也可能会给数据集带来新的噪声。但是,该方法仍然是一种非常不错的提高模型泛化能力的方法。

2. 图像中心化

由于我们输入的图像较为简单,纯色背景上绘制单一颜色的数字。所以,我们选择使用另外一种方法,即:图像物体的中心化。其思路为:

  1. 首先,在训练或推理之前,检测图像中的物体
  2. 接着,将物体缩放等比例到指定大小左右
  3. 然后,将物体设置到图像的中心位置
  4. 最后,进行训练或推理

这一步的图像预处理操作,我们主要基于 opencv 来实现。

import cv2
import matplotlib.pyplot as plt
import numpy as np
from skimage import io
import warnings
warnings.filterwarnings('ignore')


def show_image(image, title, index):
    image = np.where(image == np.min(image), 255, 0).astype(np.uint8)
    plt.subplot(1, 2, index)
    plt.axis('off')
    plt.imshow(image, cmap='gray')
    plt.title(title)


def image_center(image):
    print(np.unique(image))
    # 背景黑色,前景(数字)为白色
    image = np.where(image == np.min(image), 255, 0).astype(np.uint8)

    # 计算数字图像的轮廓(背景黑色,前景白色)
    contours, _ = cv2.findContours(image, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_SIMPLE)
    points = []
    for contour in contours:
        # 获得给定的轮廓内最小的矩形
        x, y, w, h = cv2.boundingRect(contour.squeeze())
        points.append((x, y, x + w, y + h))
    points = np.array(points)
    x1 = np.min(points[:, 0])
    y1 = np.min(points[:, 1])
    x2 = np.max(points[:, 2])
    y2 = np.max(points[:, 3])

    # 轮廓内数字图像
    image = image[y1: y2, x1: x2]

    # 缩放到中心区域大小
    scale_factor = min((300 / np.array(image.shape)).astype(np.float32))
    from skimage import transform
    image = transform.rescale(image,
                              scale_factor,
                              mode='constant',
                              cval=0,
                              anti_aliasing=True,
                              preserve_range=True)
    # 背景改为白色,数字改为黑色
    image = np.where(image == 0, 255, 0).astype(np.uint8)
    # 四周重新填充,保持原图像大小
    pad_shape = np.array((500, 500)) - image.shape
    pad_before = (pad_shape / 2).astype(np.uint32)
    pad_after = pad_shape - pad_before

    image = np.pad(image,
                   [(pad_before[0], pad_after[0]), (pad_before[1], pad_after[1])],
                   mode='constant',
                   constant_values=255)
    return image


if __name__ == '__main__':
    image = io.imread('../data/train/8-8.png')
    image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    show_image(image, 'orgin', 1)

    image = image_center(image)
    show_image(image, 'center', 2)
    plt.show()

3. 应用图像中心化

import pickle
from sklearn.svm import SVC
import cv2
import glob
import os
import numpy as np
from skimage import feature
from skimage import transform
from skimage import io


class Estimator2:

    def __init__(self):
        estimator_path = 'model/estimator.pth'
        if os.path.exists(estimator_path):
            self.estimator = pickle.load(open(estimator_path, 'rb'))

    # 使用 HOG 特征
    def extract_feature(self, image):
        # image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        image = feature.hog(image,
                            orientations=9,
                            pixels_per_cell=(100, 100),
                            cells_per_block=(3, 3),
                            visualize=False)
        return image

    # 图像中心化
    def image_center(self, image):
        # 背景黑色,前景(数字)为白色
        image = np.where(image == np.min(image), 255, 0).astype(np.uint8)

        # 计算数字图像的轮廓(背景黑色,前景白色)
        contours, _ = cv2.findContours(image, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_SIMPLE)
        points = []
        for contour in contours:
            # 获得给定的轮廓内最小的矩形
            x, y, w, h = cv2.boundingRect(contour.squeeze())
            points.append((x, y, x + w, y + h))
        points = np.array(points)
        x1 = np.min(points[:, 0])
        y1 = np.min(points[:, 1])
        x2 = np.max(points[:, 2])
        y2 = np.max(points[:, 3])

        # 轮廓内数字图像
        image = image[y1: y2, x1: x2]

        # 缩放到中心区域大小
        scale_factor = min((300 / np.array(image.shape)).astype(np.float32))
        image = transform.rescale(image,
                                  scale_factor,
                                  mode='constant',
                                  cval=0,
                                  anti_aliasing=True,
                                  preserve_range=True)
        # 背景改为白色,数字改为黑色
        image = np.where(image == 0, 255, 0).astype(np.uint8)
        # 四周重新填充,保持原图像大小
        pad_shape = np.array((500, 500)) - image.shape
        pad_before = (pad_shape / 2).astype(np.uint32)
        pad_after = pad_shape - pad_before

        image = np.pad(image,
                       [(pad_before[0], pad_after[0]), (pad_before[1], pad_after[1])],
                       mode='constant',
                       constant_values=255)
        return image


    def load_data(self, data_type='train'):
        image_fnames = glob.glob(f'data/{data_type}/[0-9]-[0-9]*.png')
        images, labels = [], []
        for fname in image_fnames:
            # 读取图像数据
            image = cv2.imread(fname)
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

            # 图像中心化
            image = self.image_center(image)

            # 提取图像特征
            image = self.extract_feature(image)
            images.append(image)

            # 解析标签
            fname = os.path.basename(fname)
            label = int(fname.split('-')[0])
            labels.append(label)

        images = np.array(images)
        labels = np.array(labels)

        return images, labels

    def train(self):
        images, labels = self.load_data(data_type='train')
        estimator = SVC()
        estimator.fit(images, labels)
        train_acc = estimator.score(images, labels)

        images, labels = self.load_data(data_type='test')
        test_acc = estimator.score(images, labels)
        pickle.dump(estimator, open('model/estimator.pth', 'wb'))
        self.estimator = estimator

        return train_acc, test_acc

    def predict(self):
        image = cv2.imread('data/train/temp.png')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        image = self.image_center(image)
        image = self.extract_feature(image)
        label = self.estimator.predict([image])

        return label[0]


if __name__ == '__main__':

    model = Estimator2()
    model.train()
未经允许不得转载:一亩三分地 » 《手写数字识别器》(九)位置敏感
评论 (0)

4 + 7 =