经过前面的准备,我们将会进行初次的算法模型训练以及封装,并测试效果。从中去分析不足,然后在后续内容中针对不足进行优化。
import numpy as np from sklearn.svm import SVC import cv2 from skimage import io import glob import pickle import os
1. 数据处理
首先,我们先对训练数据进行读取,并提取图像特征。需要说明一点,由于图像本身就是由一系列像素数据组成,所以可以把每一个像素值作为图像的特征。即:我们直接使用原始的像素数据训练算法模型。
一般来讲,这样的效果很不好,大概率会出现过拟合的问题,即:训练效果非常不错,但是测试效果很糟糕。
def extract_feature(image): # 将2d数组展开为1d数组 image = image.reshape(-1) # 图像数据归一化 image = image / 255 return image def load_data(data_type='train'): image_fnames = glob.glob(f'data/{data_type}/[0-9]-[0-9]*.png') images, labels = [], [] for fname in image_fnames: # 读取图像数据 image = cv2.imread(fname) # 转换为灰度图 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 提取图像特征 image = extract_feature(image) images.append(image) # 解析标签 fname = os.path.basename(fname) label = int(fname.split('-')[0]) labels.append(label) images = np.array(images) labels = np.array(labels) return images, labels
2. 训练评估
算法的训练和评估都比较简单。这里评估算法模型时,使用了最简单的准确率,其实对于分类模型的评估,我们也可以使用其他更为精细的方法,例如:精度、召回率、F1-Score 等。
def estimator_train(): images, labels = load_data(data_type='train') estimator = SVC() estimator.fit(images, labels) acc = estimator.score(images, labels) print('训练集 Acc:', '%.2f' % acc) # 存储模型 pickle.dump(estimator, open('model/estimator.pth', 'wb')) def estimator_eval(): images, labels = load_data(data_type='test') estimator = pickle.load(open('model/estimator.pth', 'rb')) acc = estimator.score(images, labels) print('测试集 Acc:', '%.2f' % acc)
3. 模型封装
我们将训练好的算法模型,封装成 Estimator1 类,便于在图形界面中进行调用。该类需要对外提供两个方法:
- train 方法,返回训练集和测试集的准确率
- predict 方法,返回预测的图像标签
import pickle from sklearn.svm import SVC import cv2 import glob import os import numpy as np class Estimator1: def __init__(self): self.estimator = pickle.load(open('model/estimator.pth', 'rb')) def extract_feature(self, image): image = image.reshape(-1) image = image / 255 return image def load_data(self, data_type='train'): image_fnames = glob.glob(f'data/{data_type}/[0-9]-[0-9]*.png') images, labels = [], [] for fname in image_fnames: # 读取图像数据 image = cv2.imread(fname) # 转换灰度图 image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # 提取图像特征 image = self.extract_feature(image) images.append(image) # 解析标签 fname = os.path.basename(fname) label = int(fname.split('-')[0]) labels.append(label) images = np.array(images) labels = np.array(labels) return images, labels def train(self): images, labels = self.load_data(data_type='train') estimator = SVC() estimator.fit(images, labels) train_acc = estimator.score(images, labels) images, labels = self.load_data(data_type='test') test_acc = estimator.score(images, labels) pickle.dump(estimator, open('model/estimator.pth', 'wb')) self.estimator = estimator return train_acc, test_acc def predict(self): image = cv2.imread('data/train/temp.png') image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) image = self.extract_feature(image) label = self.estimator.predict([image]) return label[0]
MainFrame 增加代码如下:
import tkinter as tk import os from Config import * class MainFrame(tk.Tk): def __init__(self): ...... # 初始化模型 from Estimator1 import Estimator1 self.model = Estimator1() ...... # 下面为按钮绑定函数 def clear(self): print('清屏') self.ccav.clear_canvas() def save(self): print('保存') self.ccav.save() def show(self): print('展开') if os.name == 'nt': os.startfile(os.path.abspath('./data/train')) else: subprocess.run(['open', os.path.abspath('./data/train')]) def open(self): print('打开') self.ccav.open() def train(self): print('训练') train_acc, test_acc = self.model.train() self.sbar.set_status('训练: %.2f 测试: %.2f' % (train_acc, test_acc)) def predict(self): print('推理') image = self.ccav.generate_image() image.save('data/train/temp.png', 'png') label = self.model.predict() self.sbar.set_status('预测: %d' % label)