7.2 Bagging and Pasting
As mentioned earlier, one way to get a diverse set of classifiers is to use very different training algorithms. Another approach is to use the same training algorithm for every predictor, but train each one on a different random subset of the training set. When sampling is performed with replacement, the method is called bagging (short for bootstrap aggregating). When sampling is performed without replacement, it is called pasting.
Both bagging and pasting allow training instances to be sampled several times across multiple predictors, but only bagging allows training instances to be sampled several times by the same predictor. A minimal sampling sketch follows below.
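To make the difference concrete, here is a small NumPy sketch (an illustration, not code from the original notebook): rng.choice with replace=True mimics bagging's sampling, while replace=False mimics pasting's.
[ ]:
import numpy as np

rng = np.random.default_rng(42)
indices = np.arange(10)  # a toy training set of 10 instances

bag_sample = rng.choice(indices, size=8, replace=True)     # bagging: with replacement
paste_sample = rng.choice(indices, size=8, replace=False)  # pasting: without replacement

print(bag_sample)    # the same index may appear more than once
print(paste_sample)  # all indices are distinct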
Once all predictors are trained, the ensemble can make a prediction for a new instance by simply aggregating the predictions of all predictors. The aggregation function is typically the statistical mode for classification (i.e., the most frequent prediction, just like a hard voting classifier) or the average for regression. Each individual predictor has a higher bias than if it were trained on the original training set, but aggregation reduces both bias and variance. Generally, the net result is that the ensemble has a similar bias but a lower variance than a single predictor trained on the original training set.
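As a rough illustration of the two aggregation functions (a sketch using made-up predictions, not the library's internal code):
[ ]:
import numpy as np

# Hypothetical class predictions from 5 predictors for 4 instances (one row per predictor).
class_preds = np.array([[0, 1, 1, 0],
                        [0, 1, 0, 0],
                        [1, 1, 1, 0],
                        [0, 1, 1, 1],
                        [0, 0, 1, 0]])

# Classification: statistical mode, i.e. majority vote across predictors.
majority = np.apply_along_axis(lambda votes: np.bincount(votes).argmax(), 0, class_preds)
print(majority)  # [0 1 1 0]

# Regression: average the predictors' numeric outputs.
reg_preds = np.array([[1.2, 3.4], [1.0, 3.0], [1.4, 3.2]])
print(reg_preds.mean(axis=0))  # [1.2 3.2]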
Predictors can be trained in parallel, on different CPU cores or even on different servers. Similarly, predictions can be made in parallel. This is one of the reasons bagging and pasting are so popular: they scale very well.
7.2.1 Bagging and Pasting with Scikit-Learn
[1]:
%matplotlib inline
[2]:
import numpy as np
import matplotlib.pyplot as plt
plt.style.use("ggplot")
[3]:
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
X, y = make_moons(n_samples=500, noise=0.30, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
len(X_train), len(X_test)
[3]:
(375, 125)
[4]:
# This is a bagging example; to use pasting (sampling without replacement) instead, just set bootstrap=False
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier
bag_clf = BaggingClassifier(DecisionTreeClassifier(), n_estimators=500, max_samples=100, bootstrap=True, n_jobs=-1)
bag_clf.fit(X_train, y_train)
y_pred = bag_clf.predict(X_test)
[5]:
from sklearn.metrics import accuracy_score
accuracy_score(y_test, y_pred)
[5]:
0.912
If the base classifier can estimate class probabilities (i.e., it has a predict_proba() method), then BaggingClassifier automatically performs soft voting instead of hard voting.
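You can reproduce this soft voting by hand (a sketch assuming every fitted sub-estimator in bag_clf.estimators_ exposes predict_proba(), which holds for decision trees): average the class probabilities over all estimators, then take the most probable class.
[ ]:
import numpy as np

# Soft voting by hand: mean class probability over all base estimators.
probas = np.mean([tree.predict_proba(X_test) for tree in bag_clf.estimators_], axis=0)
manual_pred = np.argmax(probas, axis=1)
print(accuracy_score(y_test, manual_pred))  # matches bag_clf's test accuracy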
[6]:
tree_clf = DecisionTreeClassifier()
tree_clf.fit(X_train, y_train)
y_pred = tree_clf.predict(X_test)
accuracy_score(y_test, y_pred)
[6]:
0.848
[7]:
# Plot decision boundaries
from matplotlib.colors import ListedColormap

def plot_decision_boundary(clf, X, y, axes=[-1.5, 2.45, -1, 1.5], alpha=0.5, contour=True):
    x1s = np.linspace(axes[0], axes[1], num=100)
    x2s = np.linspace(axes[2], axes[3], num=100)
    x1, x2 = np.meshgrid(x1s, x2s)
    X_new = np.c_[x1.ravel(), x2.ravel()]
    y_pred = clf.predict(X_new).reshape(x1.shape)
    custom_cmap = ListedColormap(['#fafab0', '#9898ff', '#a0faa0'])
    plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap)
    if contour:
        custom_cmap2 = ListedColormap(['#7d7d58', '#4c4c7f', '#507d50'])
        plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8)
    plt.plot(X[:, 0][y==0], X[:, 1][y==0], "yo", alpha=alpha)
    plt.plot(X[:, 0][y==1], X[:, 1][y==1], "bs", alpha=alpha)
    plt.axis(axes)
    plt.xlabel(r"$x_1$", fontsize=18)
    plt.ylabel(r"$x_2$", fontsize=18, rotation=0)
[10]:
from matplotlib import font_manager
fontP = font_manager.FontProperties(fname="./../fonts/Arial Unicode.ttf")
fontP.set_size(14)

fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)
plt.sca(axes[0])
plot_decision_boundary(tree_clf, X, y)
plt.title("Decision Tree", fontproperties=fontP)
plt.sca(axes[1])
plot_decision_boundary(bag_clf, X, y)
plt.title("Decision Trees with Bagging", fontproperties=fontP)
plt.ylabel("")
plt.show()
[11]:
bag_clf = BaggingClassifier(DecisionTreeClassifier(), n_estimators=500, max_samples=100, bootstrap=True, n_jobs=-1)
bag_clf.fit(X_train, y_train)
past_clf = BaggingClassifier(DecisionTreeClassifier(), n_estimators=500, max_samples=100, bootstrap=False, n_jobs=-1)
past_clf.fit(X_train, y_train)
[11]:
BaggingClassifier(base_estimator=DecisionTreeClassifier(), bootstrap=False,
max_samples=100, n_estimators=500, n_jobs=-1)
[12]:
from matplotlib import font_manager
fontP = font_manager.FontProperties(fname="./../fonts/Arial Unicode.ttf")
fontP.set_size(14)

fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)
plt.sca(axes[0])
plot_decision_boundary(bag_clf, X, y)
plt.title("Decision Trees with Bagging", fontproperties=fontP)
plt.sca(axes[1])
plot_decision_boundary(past_clf, X, y)
plt.title("Decision Trees with Pasting", fontproperties=fontP)
plt.ylabel("")
plt.show()
Bootstrapping introduces a bit more diversity into the subsets that each predictor is trained on, so bagging ends up with a slightly higher bias than pasting; but the extra diversity also means the predictors end up less correlated, so the ensemble's variance is reduced. Overall, bagging often produces better models, which explains why it is generally preferred. However, if you have spare time and CPU power, you can use cross-validation to evaluate both bagging and pasting and select the one that works best, as sketched below.
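A minimal comparison sketch (not part of the original notebook), using 5-fold cross-validation on the training set:
[ ]:
from sklearn.model_selection import cross_val_score

for label, clf in [("bagging", bag_clf), ("pasting", past_clf)]:
    scores = cross_val_score(clf, X_train, y_train, cv=5, scoring="accuracy")
    print(f"{label}: mean accuracy = {scores.mean():.3f} (std {scores.std():.3f})")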
7.2.2 Out-of-Bag Evaluation
With bagging, some training instances may be sampled several times for any given predictor, while others may not be sampled at all. By default, BaggingClassifier samples m training instances with replacement (bootstrap=True), where m is the size of the training set. This means that, on average, only about 63% of the training instances are sampled for each predictor. The remaining ~37% that are never sampled are called out-of-bag (oob) instances. Note that they are not the same 37% for every predictor.
As m grows, this ratio approaches 1 − exp(−1) ≈ 63.212%: the probability that any given instance is never picked in m draws with replacement is (1 − 1/m)^m, which tends to exp(−1) ≈ 36.788% as m grows.
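This is easy to verify numerically (a quick check, not from the original text): draw m indices with replacement and count the fraction of distinct instances.
[ ]:
import numpy as np

rng = np.random.default_rng(42)
m = 100_000
sampled = rng.integers(0, m, size=m)   # m draws with replacement
ratio = len(np.unique(sampled)) / m    # fraction of distinct instances sampled
print(ratio, 1 - np.exp(-1))           # both close to 0.632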
Since a predictor never sees the oob instances during training, it can be evaluated on them without the need for a separate validation set. You can evaluate the ensemble itself by averaging out the oob evaluations of each predictor.
In Scikit-Learn, you can set oob_score=True when creating a BaggingClassifier to request an automatic oob evaluation after training.
[13]:
bag_clf = BaggingClassifier(DecisionTreeClassifier(), n_estimators=500, bootstrap=True, n_jobs=-1, oob_score=True)
bag_clf.fit(X_train, y_train)
bag_clf.oob_score_
[13]:
0.904
[14]:
from sklearn.metrics import accuracy_score
accuracy_score(y_test, bag_clf.predict(X_test))
[14]:
0.904
The oob decision function for each training instance is also available through the oob_decision_function_ attribute. In this case (since the base estimator has a predict_proba() method), the decision function returns the class probabilities for each training instance.
[15]:
bag_clf.oob_decision_function_
[15]:
array([[0.43333333, 0.56666667],
[0.36781609, 0.63218391],
[1. , 0. ],
[0. , 1. ],
[0. , 1. ],
[0.04812834, 0.95187166],
[0.33333333, 0.66666667],
[0.01657459, 0.98342541],
[1. , 0. ],
[0.97714286, 0.02285714],
[0.76963351, 0.23036649],
[0.01136364, 0.98863636],
[0.8097561 , 0.1902439 ],
[0.87027027, 0.12972973],
[0.96491228, 0.03508772],
[0.05084746, 0.94915254],
[0. , 1. ],
[0.98404255, 0.01595745],
[0.9273743 , 0.0726257 ],
[1. , 0. ],
[0.00571429, 0.99428571],
[0.28021978, 0.71978022],
[0.91304348, 0.08695652],
[1. , 0. ],
[0.98235294, 0.01764706],
[0. , 1. ],
[1. , 0. ],
[1. , 0. ],
[0. , 1. ],
[0.67213115, 0.32786885],
[0. , 1. ],
[1. , 0. ],
[0. , 1. ],
[0. , 1. ],
[0.19371728, 0.80628272],
[1. , 0. ],
[0.00578035, 0.99421965],
[0.43352601, 0.56647399],
[0. , 1. ],
[1. , 0. ],
[0.25730994, 0.74269006],
[0.40659341, 0.59340659],
[1. , 0. ],
[1. , 0. ],
[0.00574713, 0.99425287],
[1. , 0. ],
[1. , 0. ],
[0.0201005 , 0.9798995 ],
[1. , 0. ],
[0. , 1. ],
[0.99468085, 0.00531915],
[0.92485549, 0.07514451],
[0.97297297, 0.02702703],
[0.95789474, 0.04210526],
[0.00571429, 0.99428571],
[0.05555556, 0.94444444],
[0.98360656, 0.01639344],
[0. , 1. ],
[0. , 1. ],
[0.01098901, 0.98901099],
[0.98314607, 0.01685393],
[0.78977273, 0.21022727],
[0.38624339, 0.61375661],
[1. , 0. ],
[0. , 1. ],
[0.71957672, 0.28042328],
[1. , 0. ],
[1. , 0. ],
[0.82513661, 0.17486339],
[1. , 0. ],
[0.59770115, 0.40229885],
[0.11956522, 0.88043478],
[0.58201058, 0.41798942],
[0.85882353, 0.14117647],
[0. , 1. ],
[0.15555556, 0.84444444],
[0.87570621, 0.12429379],
[1. , 0. ],
[0. , 1. ],
[1. , 0. ],
[0. , 1. ],
[0.05940594, 0.94059406],
[0.035 , 0.965 ],
[0.27225131, 0.72774869],
[1. , 0. ],
[0. , 1. ],
[0.80239521, 0.19760479],
[0.01507538, 0.98492462],
[0. , 1. ],
[0.01069519, 0.98930481],
[0.18905473, 0.81094527],
[1. , 0. ],
[0. , 1. ],
[0. , 1. ],
[0. , 1. ],
[0.96774194, 0.03225806],
[0.81621622, 0.18378378],
[0.00534759, 0.99465241],
[1. , 0. ],
[0.19889503, 0.80110497],
[0.59509202, 0.40490798],
[0. , 1. ],
[0.06703911, 0.93296089],
[0.48863636, 0.51136364],
[1. , 0. ],
[0.02298851, 0.97701149],
[1. , 0. ],
[0.24175824, 0.75824176],
[0.44886364, 0.55113636],
[1. , 0. ],
[0.01123596, 0.98876404],
[0.98076923, 0.01923077],
[0.25 , 0.75 ],
[0.93373494, 0.06626506],
[1. , 0. ],
[1. , 0. ],
[0. , 1. ],
[0. , 1. ],
[0.79166667, 0.20833333],
[1. , 0. ],
[0.03 , 0.97 ],
[1. , 0. ],
[1. , 0. ],
[1. , 0. ],
[0.97849462, 0.02150538],
[1. , 0. ],
[0.01197605, 0.98802395],
[0.96039604, 0.03960396],
[0.99456522, 0.00543478],
[0.02840909, 0.97159091],
[0.22340426, 0.77659574],
[0.92857143, 0.07142857],
[0.26900585, 0.73099415],
[0.98901099, 0.01098901],
[0. , 1. ],
[0.00543478, 0.99456522],
[0.70175439, 0.29824561],
[0.37628866, 0.62371134],
[0.40223464, 0.59776536],
[0.85964912, 0.14035088],
[0.93121693, 0.06878307],
[0.05 , 0.95 ],
[0.8241206 , 0.1758794 ],
[0.0052356 , 0.9947644 ],
[0. , 1. ],
[0.02105263, 0.97894737],
[0.9673913 , 0.0326087 ],
[1. , 0. ],
[1. , 0. ],
[0. , 1. ],
[0. , 1. ],
[0. , 1. ],
[0. , 1. ],
[1. , 0. ],
[1. , 0. ],
[0.97 , 0.03 ],
[1. , 0. ],
[1. , 0. ],
[1. , 0. ],
[0. , 1. ],
[0.38219895, 0.61780105],
[0.30319149, 0.69680851],
[0. , 1. ],
[0. , 1. ],
[0.36170213, 0.63829787],
[1. , 0. ],
[0.99473684, 0.00526316],
[0. , 1. ],
[1. , 0. ],
[0.01117318, 0.98882682],
[0. , 1. ],
[0.97905759, 0.02094241],
[0. , 1. ],
[0. , 1. ],
[1. , 0. ],
[0. , 1. ],
[0.6972973 , 0.3027027 ],
[0.93333333, 0.06666667],
[0. , 1. ],
[1. , 0. ],
[0.98837209, 0.01162791],
[1. , 0. ],
[0.00492611, 0.99507389],
[0. , 1. ],
[1. , 0. ],
[0.09039548, 0.90960452],
[1. , 0. ],
[0.04545455, 0.95454545],
[0.00571429, 0.99428571],
[1. , 0. ],
[0. , 1. ],
[0.02673797, 0.97326203],
[1. , 0. ],
[0.95027624, 0.04972376],
[0.78804348, 0.21195652],
[0.6576087 , 0.3423913 ],
[0.00518135, 0.99481865],
[0.11976048, 0.88023952],
[1. , 0. ],
[0.95054945, 0.04945055],
[0.97282609, 0.02717391],
[1. , 0. ],
[0.00518135, 0.99481865],
[0. , 1. ],
[0.44067797, 0.55932203],
[0.85714286, 0.14285714],
[0. , 1. ],
[0.00515464, 0.99484536],
[0.99462366, 0.00537634],
[0.01829268, 0.98170732],
[0. , 1. ],
[0.9375 , 0.0625 ],
[0. , 1. ],
[0.23350254, 0.76649746],
[0. , 1. ],
[1. , 0. ],
[0. , 1. ],
[0. , 1. ],
[0.99481865, 0.00518135],
[0.78609626, 0.21390374],
[1. , 0. ],
[0.00534759, 0.99465241],
[0.07608696, 0.92391304],
[1. , 0. ],
[0.02139037, 0.97860963],
[0. , 1. ],
[0.05487805, 0.94512195],
[0.98843931, 0.01156069],
[0.79274611, 0.20725389],
[0. , 1. ],
[0.87939698, 0.12060302],
[1. , 0. ],
[0.21022727, 0.78977273],
[0.2173913 , 0.7826087 ],
[1. , 0. ],
[0. , 1. ],
[0. , 1. ],
[0. , 1. ],
[0.25301205, 0.74698795],
[0.93264249, 0.06735751],
[0. , 1. ],
[1. , 0. ],
[0.99473684, 0.00526316],
[0. , 1. ],
[0.5480226 , 0.4519774 ],
[1. , 0. ],
[0. , 1. ],
[1. , 0. ],
[0. , 1. ],
[0. , 1. ],
[0.11797753, 0.88202247],
[0.08988764, 0.91011236],
[0.98974359, 0.01025641],
[0. , 1. ],
[1. , 0. ],
[0.44508671, 0.55491329],
[0.08241758, 0.91758242],
[0.5923913 , 0.4076087 ],
[0.65957447, 0.34042553],
[0. , 1. ],
[1. , 0. ],
[0. , 1. ],
[0. , 1. ],
[0.63783784, 0.36216216],
[0. , 1. ],
[1. , 0. ],
[0.21875 , 0.78125 ],
[0.83333333, 0.16666667],
[0.06532663, 0.93467337],
[1. , 0. ],
[0.83084577, 0.16915423],
[0. , 1. ],
[0. , 1. ],
[0.07734807, 0.92265193],
[0.01648352, 0.98351648],
[0. , 1. ],
[0.99473684, 0.00526316],
[0.9017341 , 0.0982659 ],
[0.195 , 0.805 ],
[0.92391304, 0.07608696],
[0.00520833, 0.99479167],
[0.66834171, 0.33165829],
[0.08241758, 0.91758242],
[0.97714286, 0.02285714],
[0.74444444, 0.25555556],
[0. , 1. ],
[1. , 0. ],
[0.93513514, 0.06486486],
[0. , 1. ],
[0. , 1. ],
[1. , 0. ],
[0. , 1. ],
[1. , 0. ],
[0.29378531, 0.70621469],
[0.98404255, 0.01595745],
[1. , 0. ],
[0. , 1. ],
[0.01081081, 0.98918919],
[0.83425414, 0.16574586],
[0. , 1. ],
[1. , 0. ],
[0.75595238, 0.24404762],
[0.95408163, 0.04591837],
[1. , 0. ],
[0.66666667, 0.33333333],
[0.48538012, 0.51461988],
[0. , 1. ],
[0.91712707, 0.08287293],
[0. , 1. ],
[1. , 0. ],
[0.92972973, 0.07027027],
[1. , 0. ],
[1. , 0. ],
[0.71764706, 0.28235294],
[0.15168539, 0.84831461],
[0.49714286, 0.50285714],
[0.17977528, 0.82022472],
[0. , 1. ],
[0.87 , 0.13 ],
[0.80597015, 0.19402985],
[0.01595745, 0.98404255],
[1. , 0. ],
[1. , 0. ],
[1. , 0. ],
[0. , 1. ],
[0.01980198, 0.98019802],
[0.94413408, 0.05586592],
[0.95135135, 0.04864865],
[1. , 0. ],
[0.53157895, 0.46842105],
[1. , 0. ],
[0. , 1. ],
[0.96891192, 0.03108808],
[0.03278689, 0.96721311],
[1. , 0. ],
[1. , 0. ],
[1. , 0. ],
[0. , 1. ],
[0.93888889, 0.06111111],
[0.00518135, 0.99481865],
[0.09444444, 0.90555556],
[0. , 1. ],
[0. , 1. ],
[1. , 0. ],
[1. , 0. ],
[0. , 1. ],
[1. , 0. ],
[0.02645503, 0.97354497],
[1. , 0. ],
[0.08426966, 0.91573034],
[0. , 1. ],
[0.01104972, 0.98895028],
[0. , 1. ],
[0.35602094, 0.64397906],
[0.0621118 , 0.9378882 ],
[0.21875 , 0.78125 ],
[1. , 0. ],
[0.98342541, 0.01657459],
[0.19565217, 0.80434783],
[0.97849462, 0.02150538],
[0. , 1. ],
[0. , 1. ],
[1. , 0. ],
[0.95833333, 0.04166667],
[0.32984293, 0.67015707],
[0.98113208, 0.01886792],
[1. , 0. ],
[0. , 1. ],
[1. , 0. ],
[0. , 1. ],
[0.04117647, 0.95882353],
[0.99484536, 0.00515464],
[1. , 0. ],
[0.05617978, 0.94382022],
[0.68421053, 0.31578947]])
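As a sanity check (an extra step, not part of the original notebook), you can recover the oob score by hand: take the argmax of each row of oob_decision_function_ and compare against y_train. With 500 estimators, essentially every training instance is out-of-bag for at least one estimator, so every row is well defined.
[ ]:
import numpy as np

# Pick the most probable class per training instance, then compute oob accuracy.
oob_pred = np.argmax(bag_clf.oob_decision_function_, axis=1)
print((oob_pred == y_train).mean())  # should match bag_clf.oob_score_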