6.1 训练和可视化决策树

为了理解决策树,我们建立一个决策树,然后看看它是如何做出预测的。

[1]:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()

X = iris.data[:, 2:]    # 鸢尾花花瓣的长度和宽度
y = iris.target

X.shape, y.shape
[1]:
((150, 2), (150,))
[2]:
tree_clf = DecisionTreeClassifier(max_depth=2)
tree_clf.fit(X, y)
[2]:
DecisionTreeClassifier(max_depth=2)

要将决策树可视化,首先使用export_graphviz()方法输出一个图像定义文件

[6]:
from sklearn.tree import export_graphviz

export_graphviz(
    tree_clf,
    out_file="6-1_files/iris_tree.dot",
    feature_names=iris.feature_names[2:],
    class_names=iris.target_names,
    rounded=True,
    filled=True
    )
[7]:
# 使用dot命令将.dot文件转换为.png图像文件
!dot -Tpng 6-1_files/iris_tree.dot -o 6-1_files/iris_tree.png

tree_clf