《手写数字识别器》(七)初次训练

经过前面的准备,我们将会进行初次的算法模型训练以及封装,并测试效果。从中去分析不足,然后在后续内容中针对不足进行优化。

import numpy as np
from sklearn.svm import SVC
import cv2
from skimage import io
import glob
import pickle
import os

1. 数据处理

首先,我们先对训练数据进行读取,并提取图像特征。需要说明一点,由于图像本身就是由一系列像素数据组成,所以可以把每一个像素值作为图像的特征。即:我们直接使用原始的像素数据训练算法模型。

一般来讲,这样的效果很不好,大概率会出现过拟合的问题,即:训练效果非常不错,但是测试效果很糟糕。

def extract_feature(image):
    # 将2d数组展开为1d数组
    image = image.reshape(-1)
    # 图像数据归一化
    image = image / 255
    return image


def load_data(data_type='train'):

    image_fnames = glob.glob(f'data/{data_type}/[0-9]-[0-9]*.png')
    images, labels = [], []
    for fname in image_fnames:
        # 读取图像数据
        image = cv2.imread(fname)
        # 转换为灰度图
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # 提取图像特征
        image = extract_feature(image)
        images.append(image)
        # 解析标签
        fname = os.path.basename(fname)
        label = int(fname.split('-')[0])
        labels.append(label)

    images = np.array(images)
    labels = np.array(labels)

    return images, labels

2. 训练评估

算法的训练和评估都比较简单。这里评估算法模型时,使用了最简单的准确率,其实对于分类模型的评估,我们也可以使用其他更为精细的方法,例如:精度、召回率、F1-Score 等。

def estimator_train():

    images, labels = load_data(data_type='train')
    estimator = SVC()
    estimator.fit(images, labels)
    acc = estimator.score(images, labels)
    print('训练集 Acc:', '%.2f' % acc)
    # 存储模型
    pickle.dump(estimator, open('model/estimator.pth', 'wb'))


def estimator_eval():
    images, labels = load_data(data_type='test')
    estimator = pickle.load(open('model/estimator.pth', 'rb'))
    acc = estimator.score(images, labels)
    print('测试集 Acc:', '%.2f' % acc)

3. 模型封装

我们将训练好的算法模型,封装成 Estimator1 类,便于在图形界面中进行调用。该类需要对外提供两个方法:

  1. train 方法,返回训练集和测试集的准确率
  2. predict 方法,返回预测的图像标签
import pickle
from sklearn.svm import SVC
import cv2
import glob
import os
import numpy as np


class Estimator1:

    def __init__(self):
        self.estimator = pickle.load(open('model/estimator.pth', 'rb'))

    def extract_feature(self, image):
        image = image.reshape(-1)
        image = image / 255
        return image

    def load_data(self, data_type='train'):
        image_fnames = glob.glob(f'data/{data_type}/[0-9]-[0-9]*.png')
        images, labels = [], []
        for fname in image_fnames:
            # 读取图像数据
            image = cv2.imread(fname)
            # 转换灰度图
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            # 提取图像特征
            image = self.extract_feature(image)
            images.append(image)

            # 解析标签
            fname = os.path.basename(fname)
            label = int(fname.split('-')[0])
            labels.append(label)

        images = np.array(images)
        labels = np.array(labels)

        return images, labels

    def train(self):
        images, labels = self.load_data(data_type='train')
        estimator = SVC()
        estimator.fit(images, labels)
        train_acc = estimator.score(images, labels)

        images, labels = self.load_data(data_type='test')
        test_acc = estimator.score(images, labels)
        pickle.dump(estimator, open('model/estimator.pth', 'wb'))
        self.estimator = estimator

        return train_acc, test_acc

    def predict(self):
        image = cv2.imread('data/train/temp.png')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        image = self.extract_feature(image)
        label = self.estimator.predict([image])

        return label[0]

MainFrame 增加代码如下:

import tkinter as tk
import os
from Config import *



class MainFrame(tk.Tk):

    def __init__(self):
        ......
        # 初始化模型
        from Estimator1 import Estimator1
        self.model = Estimator1()

    ......

    # 下面为按钮绑定函数
    def clear(self):
        print('清屏')
        self.ccav.clear_canvas()

    def save(self):
        print('保存')
        self.ccav.save()

    def show(self):
        print('展开')
        if os.name == 'nt':
            os.startfile(os.path.abspath('./data/train'))
        else:
            subprocess.run(['open', os.path.abspath('./data/train')])

    def open(self):
        print('打开')
        self.ccav.open()

    def train(self):
        print('训练')
        train_acc, test_acc = self.model.train()
        self.sbar.set_status('训练: %.2f 测试: %.2f' % (train_acc, test_acc))

    def predict(self):
        print('推理')
        image = self.ccav.generate_image()
        image.save('data/train/temp.png', 'png')
        label = self.model.predict()
        self.sbar.set_status('预测: %d' % label)
未经允许不得转载:一亩三分地 » 《手写数字识别器》(七)初次训练
评论 (0)

4 + 3 =