经过前面的准备,我们将会进行初次的算法模型训练以及封装,并测试效果。从中去分析不足,然后在后续内容中针对不足进行优化。
import numpy as np from sklearn.svm import SVC import cv2 from skimage import io import glob import pickle import os
1. 数据处理
首先,我们先对训练数据进行读取,并提取图像特征。需要说明一点,由于图像本身就是由一系列像素数据组成,所以可以把每一个像素值作为图像的特征。即:我们直接使用原始的像素数据训练算法模型。
一般来讲,这样的效果很不好,大概率会出现过拟合的问题,即:训练效果非常不错,但是测试效果很糟糕。
def extract_feature(image):
# 将2d数组展开为1d数组
image = image.reshape(-1)
# 图像数据归一化
image = image / 255
return image
def load_data(data_type='train'):
image_fnames = glob.glob(f'data/{data_type}/[0-9]-[0-9]*.png')
images, labels = [], []
for fname in image_fnames:
# 读取图像数据
image = cv2.imread(fname)
# 转换为灰度图
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 提取图像特征
image = extract_feature(image)
images.append(image)
# 解析标签
fname = os.path.basename(fname)
label = int(fname.split('-')[0])
labels.append(label)
images = np.array(images)
labels = np.array(labels)
return images, labels
2. 训练评估
算法的训练和评估都比较简单。这里评估算法模型时,使用了最简单的准确率,其实对于分类模型的评估,我们也可以使用其他更为精细的方法,例如:精度、召回率、F1-Score 等。
def estimator_train():
images, labels = load_data(data_type='train')
estimator = SVC()
estimator.fit(images, labels)
acc = estimator.score(images, labels)
print('训练集 Acc:', '%.2f' % acc)
# 存储模型
pickle.dump(estimator, open('model/estimator.pth', 'wb'))
def estimator_eval():
images, labels = load_data(data_type='test')
estimator = pickle.load(open('model/estimator.pth', 'rb'))
acc = estimator.score(images, labels)
print('测试集 Acc:', '%.2f' % acc)
3. 模型封装
我们将训练好的算法模型,封装成 Estimator1 类,便于在图形界面中进行调用。该类需要对外提供两个方法:
- train 方法,返回训练集和测试集的准确率
- predict 方法,返回预测的图像标签
import pickle
from sklearn.svm import SVC
import cv2
import glob
import os
import numpy as np
class Estimator1:
def __init__(self):
self.estimator = pickle.load(open('model/estimator.pth', 'rb'))
def extract_feature(self, image):
image = image.reshape(-1)
image = image / 255
return image
def load_data(self, data_type='train'):
image_fnames = glob.glob(f'data/{data_type}/[0-9]-[0-9]*.png')
images, labels = [], []
for fname in image_fnames:
# 读取图像数据
image = cv2.imread(fname)
# 转换灰度图
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# 提取图像特征
image = self.extract_feature(image)
images.append(image)
# 解析标签
fname = os.path.basename(fname)
label = int(fname.split('-')[0])
labels.append(label)
images = np.array(images)
labels = np.array(labels)
return images, labels
def train(self):
images, labels = self.load_data(data_type='train')
estimator = SVC()
estimator.fit(images, labels)
train_acc = estimator.score(images, labels)
images, labels = self.load_data(data_type='test')
test_acc = estimator.score(images, labels)
pickle.dump(estimator, open('model/estimator.pth', 'wb'))
self.estimator = estimator
return train_acc, test_acc
def predict(self):
image = cv2.imread('data/train/temp.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
image = self.extract_feature(image)
label = self.estimator.predict([image])
return label[0]
MainFrame 增加代码如下:
import tkinter as tk
import os
from Config import *
class MainFrame(tk.Tk):
def __init__(self):
......
# 初始化模型
from Estimator1 import Estimator1
self.model = Estimator1()
......
# 下面为按钮绑定函数
def clear(self):
print('清屏')
self.ccav.clear_canvas()
def save(self):
print('保存')
self.ccav.save()
def show(self):
print('展开')
if os.name == 'nt':
os.startfile(os.path.abspath('./data/train'))
else:
subprocess.run(['open', os.path.abspath('./data/train')])
def open(self):
print('打开')
self.ccav.open()
def train(self):
print('训练')
train_acc, test_acc = self.model.train()
self.sbar.set_status('训练: %.2f 测试: %.2f' % (train_acc, test_acc))
def predict(self):
print('推理')
image = self.ccav.generate_image()
image.save('data/train/temp.png', 'png')
label = self.model.predict()
self.sbar.set_status('预测: %d' % label)



冀公网安备13050302001966号