《手写数字识别器》(五)Scikit-Learn

《手写数字识别器》中会应用支持向量机算法、以及相关的参数搜索方法。这两个在 Scikit-Learn 中都有实现。这一节,我们将会学习两个相关的 API 的使用。

  1. 数据介绍
  2. 算法使用
  3. 网格搜索

1. 数据介绍

鸢尾花数据集是一个非常有名的机器学习数据集,它由美国统计学家 R.A. Fisher 收集。这个数据集包含了150个鸢尾花的测量数据,每个数据包含了花萼长度、花萼宽度、花瓣长度、花瓣宽度四个特征,它们用来描述鸢尾花的类型。

鸢尾花数据集共有三个类别,分别是Setosa Iris(狗尾草鸢尾)、Versicolour Iris(彩色鸢尾)和Virginica Iris(蓝色鸢尾)。这三个类别的样本数量相等,每个类别有50个样本。鸢尾花数据集通常用于分类任务。

我们在这一小节,主要学习两个点:

  1. 了解数据集相关信息
  2. 掌握数据集分割使用
from sklearn.datasets import load_iris
from collections import Counter
from sklearn.model_selection import train_test_split

# 1. 数据集介绍
def test01():
    # 加载鸢尾花数据集
    data = load_iris()
    # 特征名字
    print(data.feature_names)
    # 特征值
    print(data.data)
    # 目标值名字
    print(data.target_names)
    # 目标值
    print(data.target)
    # 样本分布
    print(Counter(data.target))


# 2. 分割数据集
def test02():
    data = load_iris()
    # 数据集分割
    # test_size : 设置测试集占比
    # stratify : 设置按照类别比例分割
    x_train, x_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2, stratify=data.target)
    print(Counter(y_train))
    print(Counter(y_test))


if __name__ == '__main__':
    test02()

2. 算法使用

sklearn 的 SVC 使用较为简单,重点需要了解其几个重要参数:

  1. C:float 默认值是1.0。用于平衡分类器的复杂度和错误惩罚的参数。C 越大,对训练样本分类的正确率越高,但泛化能力越弱;C 越小,对误分类的惩罚减小,允许容错,泛化能力较强。
  2. kernel:用于指定使用的核函数,主要包括:linear、poly、rbf、sigmoid
  3. gamma:参数值为 {‘scale’, ‘auto’} or float, default=’scale’,注意该值为非负数,并且该参数会使用到 poly、sigmoid、rbf 核函数计算过程中
  4. degree:int 默认值是 3,用于核函数计算时的参数
  5. coef0:float 默认是值为 0.0,它是 poly 和 sigmoid 的截距

核函数文档链接:https://scikit-learn.org/stable/modules/svm.html#svm-kernels

from sklearn.svm import SVC
from sklearn.datasets import load_iris
from collections import Counter
from sklearn.model_selection import train_test_split


def test():

    # 数据处理
    data = load_iris()
    x_train, x_test, y_train, y_test = \
        train_test_split(data.data, data.target, test_size=0.2, random_state=67)

    # 算法训练
    estimator = SVC()
    estimator.fit(x_train, y_train)

    # 算法评估
    acc = estimator.score(x_test, y_test)
    print('Acc: %.2f' % acc)


if __name__ == '__main__':
    test()

3. 网格搜索

在scikit-learn 库中提供了 GridSearchCV(网格搜索) 用于在给定的参数范围内自动搜索最优的参数组合。具体的原理是在指定的参数网格上执行训练/验证过程,并记录下每个参数组合对应的模型表现(例如准确率),最后选择最优的参数组合。

  • 创建一个估计器对象,并为其指定一组参数。
  • 创建一个 GridSearchCV 对象,并为其指定估计器对象和参数网格。
  • 训练 GridSearchCV 对象,并获取最优的参数组合。

请注意,GridSearchCV 是一种非常耗时的方法,尤其是当参数网格较大时,因此建议在实验过程中使用较小的参数网格或者采用其他更为高效的搜索策略。

from sklearn.svm import SVC
from sklearn.datasets import load_iris
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
import numpy as np
import pandas as pd


def test():
    # 数据处理
    data = load_iris()
    x_train, x_test, y_train, y_test = \
        train_test_split(data.data, data.target, test_size=0.2, random_state=67)

    # 超参数搜索+模型训练
    param_grid = {
        'C': np.arange(0.1, 2.0, 0.1),
        'kernel': ['linear', 'poly', 'sigmoid', 'rbf'],
        'degree': [1, 3, 5, 7, 9],
        'gamma': ['scale', 'auto']
    }

    # 如果想固定某个参数,可以通过 estimator 来指定设置了固定参数的模型
    estimator = GridSearchCV(estimator=SVC(), param_grid=param_grid, cv=3)
    estimator.fit(x_train, y_train)
    print('最优模型:', estimator.best_estimator_)
    print('最优参数:', estimator.best_params_)
    print('最优分数:', '%.2f' % estimator.best_score_)

    pd.DataFrame(estimator.cv_results_).to_csv('demo.csv')


if __name__ == '__main__':
    test()

未经允许不得转载:一亩三分地 » 《手写数字识别器》(五)Scikit-Learn
评论 (0)

1 + 9 =