这篇文章我想讲解的是 scikit-learn 中 SVC 的二分类、多分类场景下 ovo、ovr 决策函数的计算过程,以了解 SVC 进行推理时的逻辑。从而加深对 SVC 的理解。
决策函数公式得到决策值之后,直接判断符号,可得出类别标签。
1. 二分类 SVC 的决策函数
from sklearn.svm import SVC from sklearn.metrics.pairwise import rbf_kernel from sklearn.datasets import make_classification import numpy as np def test(): inputs, labels = make_classification(n_samples=1000, n_features=20, n_informative=15, n_classes=2, random_state=42) estimator = SVC(kernel='rbf', gamma=0.3) estimator.fit(inputs, labels) sample = [inputs[0]] # 支持向量 svecs = estimator.support_vectors_ # 对偶系数 dcoef = estimator.dual_coef_ # 截距 icept = estimator.intercept_ # 参数 gamma = estimator.gamma print('模型对偶系数:', dcoef.shape) print('模型截距参数:', icept.shape) print('模型类别列表:', estimator.classes_) # 手动计算决策值 v1 = dcoef @ rbf_kernel(sample, svecs, gamma=gamma).T + icept print(v1.squeeze()) # 模型计算决策值 v2 = estimator.decision_function(sample) print(v2.squeeze()) if __name__ == '__main__': test()
2. 多分类 SVC OVO 决策函数
在 SVC 多分类的场景下,计算 ovo 决策结果,需要用到以下训练得到的属性值:
- n_support_ 表示每个分类支持向量的数量
- support_vectors_ 所有类别的支持向量
- dual_coef_ 这是是 SVC 训练得到的最重要的对偶系数,它实际等于 α * y,即:每一个样本的支持向量的朗格朗日乘子乘以该样本的标签(-1 或者 +1)
- intercept_ 是 ovo 中每一个模型的截距,对于 10 类别而言,会训练 (10 * (10 -1)) / 2 = 45 个二分类器,该属性中存储了这 45 个二分类器训练得到的截距。对应的顺序为:01 02 03 …09 12 13 14 15… 19..23…
这个计算过程理解的重点是对偶系数,其结构含义如下(截图官网文档):
对偶系数矩阵的理解可能较为复杂一些,但是它是理解 ovo 决策值如何计算的极其重要的一部分。
from sklearn.svm import SVC from sklearn.metrics.pairwise import rbf_kernel from sklearn.datasets import make_classification import numpy as np def test(): inputs, labels = make_classification(n_samples=100, n_features=20, n_informative=15, n_classes=10, random_state=42) estimator = SVC(kernel='rbf', gamma=0.3) estimator.fit(inputs, labels) sample = [inputs[0]] # 支持向量 svecs = estimator.support_vectors_ # 每个类别支持向量数量 nsupt = np.cumsum([0] + estimator.n_support_.tolist()) # 对偶系数 dcoef = estimator.dual_coef_ # 截距 icept = estimator.intercept_ # 参数 gamma = estimator.gamma print('模型对偶系数:', dcoef.shape) print('模型截距参数:', icept.shape) print('模型类别列表:', estimator.classes_) print('类别支持数量:', nsupt) # 手动计算决策值 # 为截距生成索引 intercept_indexes = [f'{i}{j}' for i in range(10) for j in range(i + 1, 10)] # 计算输入样本与所有支持向量的加权相似度 scores = dcoef * rbf_kernel(sample, svecs, gamma=gamma) # 计算每个类别的支持向量与其他类别的支持向量的相似度分数 class_scores = [] for s, e in zip(nsupt[:-1], nsupt[1:]): class_scores.append(scores[:, s:e].sum(axis=-1)) # 将分数展开 class_scores = np.array(class_scores) class_scores = class_scores.ravel() class_score_indexes = [f'{i}{j}' for i in range(10) for j in range(10) if i != j] v1 = [] for flag in intercept_indexes: a_index = class_score_indexes.index(flag) b_index = class_score_indexes.index(flag[::-1]) c_index = intercept_indexes.index(flag) v1. append(class_scores[a_index] + class_scores[b_index] + icept[c_index]) print(np.array(v1)) # 模型计算决策值 estimator.decision_function_shape = 'ovo' v2 = estimator.decision_function(sample) print(v2.squeeze()) print('手动计算和API计算结果:', np.all(np.array(v1) == v2)) if __name__ == '__main__': test()
模型对偶系数: (9, 100) 模型截距参数: (45,) 模型类别列表: [0 1 2 3 4 5 6 7 8 9] 类别支持数量: [ 0 10 20 30 41 51 60 69 79 89 100] [ 4.26630481e-04 1.15984298e-11 -9.09090910e-02 1.45700034e-09 1.00000000e-01 1.00000000e-01 3.20884059e-10 -1.00000000e+00 -9.09094941e-02 -3.93785796e-04 -9.09090910e-02 -3.65702293e-04 9.95395590e-02 9.95395438e-02 -3.93785411e-04 -1.00039376e+00 -9.09093704e-02 -9.09090918e-02 1.54111940e-09 1.00000000e-01 1.00000000e-01 3.12410279e-10 -1.00000000e+00 -9.09091250e-02 9.09090909e-02 1.81818182e-01 1.81818182e-01 9.09090922e-02 -9.09090904e-01 4.36981889e-09 9.99999994e-02 9.99999983e-02 1.11681132e-10 -1.00000000e+00 -9.09090851e-02 -4.46671502e-09 -1.00000006e-01 -9.99414063e-01 -1.81818177e-01 -9.99999995e-02 -9.99414063e-01 -1.81818177e-01 -1.00000000e+00 -9.09090851e-02 9.09090915e-01] [ 4.26630481e-04 1.15984298e-11 -9.09090910e-02 1.45700034e-09 1.00000000e-01 1.00000000e-01 3.20884059e-10 -1.00000000e+00 -9.09094941e-02 -3.93785796e-04 -9.09090910e-02 -3.65702293e-04 9.95395590e-02 9.95395438e-02 -3.93785411e-04 -1.00039376e+00 -9.09093704e-02 -9.09090918e-02 1.54111940e-09 1.00000000e-01 1.00000000e-01 3.12410279e-10 -1.00000000e+00 -9.09091250e-02 9.09090909e-02 1.81818182e-01 1.81818182e-01 9.09090922e-02 -9.09090904e-01 4.36981889e-09 9.99999994e-02 9.99999983e-02 1.11681132e-10 -1.00000000e+00 -9.09090851e-02 -4.46671502e-09 -1.00000006e-01 -9.99414063e-01 -1.81818177e-01 -9.99999995e-02 -9.99414063e-01 -1.81818177e-01 -1.00000000e+00 -9.09090851e-02 9.09090915e-01] 手动计算和API计算结果: True
3. 多分类 SVC OVR 决策函数
在 SVC 中,当我们把 decision_function_shape 设置 ovr 时,实际内部仍然会先计算 ovo 决策值,然后再由 ovo 的决策值转换为 ovr 的决策值。ovr 中,每一个类别都对应了预测分数,我们最后将其归为分数最大的类别标签即可。
from sklearn.svm import SVC from sklearn.metrics.pairwise import rbf_kernel from sklearn.datasets import make_classification import numpy as np def test(): inputs, labels = make_classification(n_samples=100, n_features=20, n_informative=15, n_classes=10, random_state=42) estimator = SVC(kernel='rbf', gamma=0.3) estimator.fit(inputs, labels) sample = [inputs[0]] # 支持向量 svecs = estimator.support_vectors_ # 每个类别支持向量数量 nsupt = np.cumsum([0] + estimator.n_support_.tolist()) # 对偶系数 dcoef = estimator.dual_coef_ # 截距 icept = estimator.intercept_ # 参数 gamma = estimator.gamma print('模型对偶系数:', dcoef.shape) print('模型截距参数:', icept.shape) print('模型类别列表:', estimator.classes_) print('类别支持数量:', nsupt) # 手动计算决策值 estimator.decision_function_shape = 'ovo' ovo = estimator.decision_function(sample) # ovr 分数是由 ovo 转换得到 from sklearn.utils.multiclass import _ovr_decision_function # 第一个参数:每一个分类器预测的类别 # 第二个参数:预测为 +1 类别的分数或概率 # 第三个参数:类别的数量 v1 = _ovr_decision_function(ovo < 0, -ovo, len(estimator.classes_)) print(v1.squeeze()) # 模型计算决策值 estimator.decision_function_shape = 'ovr' v2 = estimator.decision_function(sample) print(v2.squeeze()) if __name__ == '__main__': test()
模型对偶系数: (9, 100) 模型截距参数: (45,) 模型类别列表: [0 1 2 3 4 5 6 7 8 9] 类别支持数量: [ 0 10 20 30 41 51 60 69 79 89 100] [ 5.83489857 1.83461706 4.83489581 7.97222223 3.83489343 -0.21688867 0.78311133 2.83489581 9.29938003 6.97222241] [ 5.83489857 1.83461706 4.83489581 7.97222223 3.83489343 -0.21688867 0.78311133 2.83489581 9.29938003 6.97222241]