6.1 训练和可视化决策树¶
为了理解决策树,我们建立一个决策树,然后看看它是如何做出预测的。
[1]:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
iris = load_iris()
X = iris.data[:, 2:] # 鸢尾花花瓣的长度和宽度
y = iris.target
X.shape, y.shape
[1]:
((150, 2), (150,))
[2]:
tree_clf = DecisionTreeClassifier(max_depth=2)
tree_clf.fit(X, y)
[2]:
DecisionTreeClassifier(max_depth=2)
要将决策树可视化,首先使用export_graphviz()
方法输出一个图像定义文件
[6]:
from sklearn.tree import export_graphviz
export_graphviz(
tree_clf,
out_file="6-1_files/iris_tree.dot",
feature_names=iris.feature_names[2:],
class_names=iris.target_names,
rounded=True,
filled=True
)
[7]:
# 使用dot命令将.dot文件转换为.png图像文件
!dot -Tpng 6-1_files/iris_tree.dot -o 6-1_files/iris_tree.png