在学习决策树原理之前,我们先感性的了解下决策树的构建和推理过程、以及 API 的使用。
1. 分类决策树
分类决策树基于训练数据构建一个树状结构,每个节点代表一个特征,每个分支代表一个可能的答案,最终叶节点代表一个分类标签。
训练数据:
决策树构建:
其中,0 代表否,1 代表是。
决策树预测:
2. 回归决策树
回归决策树基于训练数据构建一个树状结构,每个节点代表一个特征,每个分支代表一个预测走向,最终叶节点代表预测的结果。
训练数据:
决策树构建:
决策树预测:
3. 决策树使用
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import plot_tree import pickle import matplotlib.pyplot as plt from sklearn.datasets import load_iris import warnings warnings.filterwarnings('ignore') if __name__ == '__main__': # 鸢尾花数据集 data = load_iris() inputs = data['data'] labels = data['target'] # 决策树训练 estimator = DecisionTreeClassifier(random_state=42) estimator.fit(inputs, labels) # 决策树预测 - 输出标签 y_preds = estimator.predict(inputs) # 决策树预测 - 输出概率 y_proba = estimator.predict_proba(inputs) print(y_proba) # 决策树评估 acc = estimator.score(inputs, labels) print('Acc:', acc) # 决策树存储 pickle.dump(estimator, open('estimator.pkl', 'wb')) # 决策树加载 estimator = pickle.load(open('estimator.pkl', 'rb')) # 决策树可视化 plt.figure(figsize=(22, 10)) plot_tree(estimator, feature_names=data['feature_names'], filled=True, rounded=True, fontsize=14, precision=2) plt.show()
决策树可视化: