代价复杂度剪枝(Cost-Complexity Pruning)

在 scikit-learn 的决策树实现中,使用 CCP(Cost-Complexity Pruning)代价复杂度剪枝,用于避免过拟合并提高决策树的泛化能力。

1. 剪枝原理

决策树中包含了很多子树,一棵子树是否应该剪掉,得通过某个指标来对其代价复杂度进行度量,CCP 使用下面的指标来进行度量:

  • R(T) 表示子树的加权不纯度(以基于基尼不纯度的分类决策树为例)
  • \(\tilde{T}\) 表示子树的叶子节点数量
  • α 表示调节参数

通过 CCP 公式,我们可以从不纯度叶子结点数量两个角度来度量一棵树的复杂度。其中,不纯度表示树的拟合效果,叶子结点数量表示树的复杂度。

  1. 如果 \(R_{\alpha}(T)\) 值比较大,说明决策树更加复杂
  2. 如果 \(R_{\alpha}(T)\) 值比较小,说明决策树更加简单

在公式中,超参数 α 用来调节度量决策树代价复杂度时,我们更倾向于关心不纯度,还是叶子结点数量:

  1. 如果设置较大的 α 值,说明我们更加关心叶子结点数量
  2. 如果设置较小的 α 值,说明我们更加关心叶树的不纯度

另外,我们还需要理解一点是:树的剪枝过程并不是只剪一次,如果一直剪下去的话,最后只会剩下一个根节点。并且,每次剪枝都会在上次剪枝后的树上进行。关于 α 还应该理解:

  1. 如果 α 值过大,倾向于过度对决策树进行剪枝
  2. 如果 α 值较小,倾向于对决策树进行更加保守的剪枝

2. 复杂度计算

树的复杂度计算包括对某个子树的计算、以及对叶子结点的计算。树的复杂度计算公式如下:

  • \(T_t\) 表示以结点 t 为根节点的子树
  • \(\tilde{T}\) 表示以结点 t 为根节点的子树的叶子结点数量

如果要计算某个叶子结点的复杂度 \(R(t)\),则使用下面的公式:

  • R(t) 表示某个叶子结点的加权不纯度。

我们以下面的树(使用熵来度量每个节点的不纯度,值越大,不纯度越高)为例,来手动计算下树和叶子结点的复杂度:

计算 R(t=0) 树的 CCP 复杂度,并且 α =0.1,则:

加权不纯度:

  • 第 3 号叶子结点:0.0 * 1/14 = 0
  • 第 4 号叶子结点:0.0 * 4/14 = 0
  • 第 5 号叶子结点:0.722 * 5/14 = 0.258
  • 第 6 号叶子结点:0.0 * 4/14

叶子结点复杂度:0.1 * 4 = 0.4

最终得到 R(t=0) 子树的复杂度为: 0.258 + 0.4 = 0.658

计算 R(t=5) 叶子结点的 CCP 复杂度, 0.722 * 5/14 + 0.1 = 0.358

Day,Weather,Temperature,Humidity,Wind,Play
1,Sunny,25,High,Weak,No
2,Sunny,24,High,Strong,No
3,Overcast,28,High,Weak,Yes
4,Rain,22,High,Weak,Yes
5,Rain,20,Normal,Weak,Yes
6,Rain,18,Normal,Strong,No
7,Overcast,21,Normal,Strong,Yes
8,Sunny,23,High,Weak,No
9,Sunny,19,Normal,Weak,Yes
10,Rain,24,Normal,Weak,Yes
11,Sunny,23,Normal,Strong,Yes
12,Overcast,26,High,Strong,Yes
13,Overcast,27,Normal,Weak,Yes
14,Rain,22,High,Strong,No
from sklearn.tree import DecisionTreeClassifier
import pandas as pd
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')


data = pd.read_csv('data-4.csv')
inputs = data[['Weather', 'Temperature', 'Humidity', 'Wind']]
labels = data['Play']
# 处理类别特征
inputs = pd.get_dummies(inputs)

estimator = DecisionTreeClassifier(criterion='entropy', ccp_alpha=0.2)
estimator.fit(inputs, labels)

plt.figure(figsize=(20, 15))
plot_tree(estimator,
          feature_names=list(inputs.columns),
          class_names=None,
          node_ids=True,
          precision=3,
          max_depth=None,
          fontsize=18,
          filled=True)
plt.show()

3. α 临界值

对于决策树中的每一个非叶子节点,计算其剪枝前剪枝后的复杂度差值,如下公式所示:

如果这个差值大于 0,说明剪枝之后能够降低整个决策树的复杂度。我们将该公式展开:

  • α 值大于左半部分的值时,就进行剪枝。即:左半部分值其实是一个临界 α 值,最大不剪枝的 α 值
  • α 值也表示平均不纯度的增加程度,即:剪掉某个子树之后决策树不纯度的增加程度

当有多个非叶子结点都满足剪枝条件时,一般选择 α 值最小的子树进行剪枝。原因是为了在最小的不纯度增加的代价下,实现模型的简化

接下来,通过一个例子来理解 α 临界值的计算过程:

完全生长的决策树
node#1node#4node#5node#7node#9
\(R(t)\)0.414*12/240.98*12/240.863*7/240.65*6/240.918*3/24
\(R(T_{t})\)00000
\(|\tilde{T}|\)25432
α 临界值0.2070.12250.08390.081250.11475

我们选择最小临界值 α= 0.08125 ,并设置 α 比该值稍微大一些。

第一次剪枝后的决策树(node#7)

第二次剪枝后的决策树(node#5)
第三次剪枝后的决策树(node#1)
第四次剪枝后的决策树(node#0)

在剪枝之后,在新的树上,计算所有非叶子节点的剪枝代价,继续选择 α 最小的结点进行剪枝。当然,按照这种方式一直剪下去的话,无论多么复杂的决策树都会被剪成一个只有根节点的树桩,并且剪得越多,树的不纯度就越高,并且在剪枝的过程中也会获得多个 α 值,这些 α 值也组成了一个决策树的剪枝路径。

我们可以使用 from sklearn.tree import DecisionTreeClassifiercost_complexity_pruning_path 函数来计算决策树的剪枝路径:

Day,Weather,Temperature,Humidity,Wind,Play
1,Sunny,25,High,Weak,No
2,Sunny,24,High,Strong,No
3,Overcast,28,High,Weak,Yes
4,Rain,22,High,Weak,Yes
5,Rain,20,Normal,Weak,Yes
6,Rain,18,Normal,Strong,No
7,Overcast,21,Normal,Strong,Yes
8,Sunny,23,High,Weak,No
9,Sunny,19,Normal,Weak,Yes
10,Rain,24,Normal,Weak,Yes
11,Sunny,23,Normal,Strong,Yes
12,Overcast,26,High,Strong,Yes
13,Overcast,27,Normal,Weak,Yes
14,Rain,22,High,Strong,No
15,Sunny,26,High,Strong,No
16,Sunny,22,Normal,Weak,Yes
17,Overcast,25,Normal,Strong,Yes
18,Rain,21,High,Weak,Yes
19,Sunny,24,High,Strong,No
20,Sunny,20,Normal,Weak,Yes
21,Rain,23,Normal,Weak,Yes
22,Overcast,24,High,Strong,Yes
23,Overcast,28,Normal,Weak,Yes
24,Rain,19,High,Strong,No
from sklearn.tree import DecisionTreeClassifier
import pandas as pd
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')



if __name__ == '__main__':

    # 读取训练数据
    data = pd.read_csv('data-3.csv')
    inputs = data[['Weather', 'Temperature', 'Humidity', 'Wind']]
    labels = data['Play']
    inputs = pd.get_dummies(inputs)

    # 训练完全增长的决策树
    estimator = DecisionTreeClassifier(criterion='entropy')
    estimator.fit(inputs, labels)

    # 计算每次剪枝的 ccp_alpha 临界值以及对应的不纯度值
    pth = estimator.cost_complexity_pruning_path(inputs, labels)
    print('ccp_alphas:', pth['ccp_alphas'])
    print('impurities:', pth['impurities'])
ccp_alphas: [0.         0.0812528  0.08923789 0.20690843 0.22982195]
impurities: [0.         0.16250561 0.2517435  0.45865192 0.91829583]
  • 当选择的 α 值在 [0, 0.0812528] 时,整棵树的不纯度为 0
  • 当选择的 α 值在 (0.0812528,0.08923789] 时,整棵树的不纯度为 0.16250561
  • 当选择的 α 值在 (0.08923789,0.20690843] 时,整棵树的不纯度为 0.2517435
  • 当选择的 α 值在 (0.20690843,0.22982195] 时,整棵树的不纯度为 0.45865192
  • 当选择的 α 值在 (0.22982195,] 时,整棵树的不纯度为 0.91829583

至此,决策树的 CCP 剪枝算法就讲解完毕了,希望对你有所帮助。

未经允许不得转载:一亩三分地 » 代价复杂度剪枝(Cost-Complexity Pruning)
评论 (0)

1 + 2 =