7.2 Bagging and Pasting

As mentioned earlier, one way to get a diverse set of classifiers is to use very different training algorithms. Another approach is to use the same training algorithm for every predictor, but train each one on a different random subset of the training set. When sampling is performed with replacement, this method is called bagging (short for bootstrap aggregating). When sampling is performed without replacement, it is called pasting.

Both bagging and pasting allow training instances to be sampled several times across multiple predictors, but only bagging allows training instances to be sampled several times by the same predictor.
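The two schemes differ only in whether indices may repeat. A minimal NumPy sketch (with a made-up training-set size m) shows the contrast:

```python
import numpy as np

rng = np.random.default_rng(42)
m = 10  # hypothetical training-set size

# bagging: draw m indices WITH replacement -- duplicates are possible,
# so the same instance can be sampled several times by one predictor
bagging_idx = rng.choice(m, size=m, replace=True)

# pasting: draw WITHOUT replacement -- every index appears at most once
pasting_idx = rng.choice(m, size=m, replace=False)
```

Note that drawing size=m without replacement just yields a permutation of the whole set; in practice pasting uses max_samples < m so each predictor sees a proper subset.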

Once all predictors are trained, the ensemble can make a prediction for a new instance by simply aggregating the predictions of all predictors. The aggregation function is typically the statistical mode (i.e., the most frequent prediction, just like with a hard voting classifier) for classification, or the average for regression. Each individual predictor has a higher bias than if it were trained on the original training set, but aggregation reduces both bias and variance. The net result is that the ensemble has a comparable bias but a lower variance than a single predictor trained on the original training set.
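The aggregation step can be sketched in a few lines of NumPy; the prediction arrays below are made-up values for three hypothetical predictors:

```python
import numpy as np

# classification: each row holds one predictor's class predictions
clf_preds = np.array([[0, 1, 1, 0],
                      [0, 1, 0, 0],
                      [1, 1, 1, 0]])
# statistical mode per instance (majority vote, as in hard voting)
votes = np.apply_along_axis(lambda c: np.bincount(c).argmax(), 0, clf_preds)
# votes -> [0, 1, 1, 0]

# regression: simply average the predictions per instance
reg_preds = np.array([[2.1, 3.0],
                      [1.9, 3.4],
                      [2.0, 3.2]])
avg = reg_preds.mean(axis=0)  # -> [2.0, 3.2]
```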

You can train the predictors in parallel, on different CPU cores or even on different servers. Similarly, predictions can be made in parallel. This is one of the reasons bagging and pasting are so popular: they scale very well.

7.2.1 Bagging and Pasting in Scikit-Learn

[1]:
%matplotlib inline
[2]:
import numpy as np
import matplotlib.pyplot as plt
plt.style.use("ggplot")
[3]:
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split

X, y = make_moons(n_samples=500, noise=0.30, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
len(X_train), len(X_test)
[3]:
(375, 125)
[4]:
# This is a bagging example; to use pasting (sampling without replacement) instead, just set bootstrap=False
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier

bag_clf = BaggingClassifier(DecisionTreeClassifier(), n_estimators=500, max_samples=100, bootstrap=True, n_jobs=-1)
bag_clf.fit(X_train, y_train)
y_pred = bag_clf.predict(X_test)
[5]:
from sklearn.metrics import accuracy_score
accuracy_score(y_test, y_pred)
[5]:
0.912

If the base classifier can estimate class probabilities (i.e., it has a predict_proba() method), the BaggingClassifier automatically performs soft voting instead of hard voting.
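A minimal sketch of what soft voting amounts to, using made-up predict_proba() outputs from three hypothetical predictors for a single instance: the class probabilities are averaged, and the class with the highest mean probability wins.

```python
import numpy as np

# shape: (n_predictors, n_instances, n_classes) -- toy probability values
probas = np.array([[[0.9, 0.1]],
                   [[0.6, 0.4]],
                   [[0.3, 0.7]]])
mean_proba = probas.mean(axis=0)        # -> [[0.6, 0.4]]
prediction = mean_proba.argmax(axis=1)  # -> class 0
```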

[6]:
tree_clf = DecisionTreeClassifier()
tree_clf.fit(X_train, y_train)
y_pred = tree_clf.predict(X_test)
accuracy_score(y_test, y_pred)
[6]:
0.848
[7]:
# Plot the decision boundaries
from matplotlib.colors import ListedColormap

def plot_decision_boundary(clf, X, y, axes=[-1.5, 2.45, -1, 1.5], alpha=0.5, contour=True):
    x1s = np.linspace(axes[0], axes[1], num=100)
    x2s = np.linspace(axes[2], axes[3], num=100)
    x1, x2 = np.meshgrid(x1s, x2s)
    X_new = np.c_[x1.ravel(), x2.ravel()]
    y_pred = clf.predict(X_new).reshape(x1.shape)
    custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])
    plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap)
    if contour:
        custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])
        plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8)
    plt.plot(X[:, 0][y==0], X[:, 1][y==0], "yo", alpha=alpha)
    plt.plot(X[:, 0][y==1], X[:, 1][y==1], "bs", alpha=alpha)
    plt.axis(axes)
    plt.xlabel(r"$x_1$", fontsize=18)
    plt.ylabel(r"$x_2$", fontsize=18, rotation=0)
[10]:
from matplotlib import font_manager

fontP = font_manager.FontProperties(fname="./../fonts/Arial Unicode.ttf")
# fontP.set_family('monospace')
fontP.set_size(14)

fig, axes = plt.subplots(ncols=2, figsize=(10,4), sharey=True)
plt.sca(axes[0])
plot_decision_boundary(tree_clf, X, y)
plt.title("Decision Tree", fontproperties=fontP)
plt.sca(axes[1])
plot_decision_boundary(bag_clf, X, y)
plt.title("Decision Tree with Bagging", fontproperties=fontP)
plt.ylabel("")
plt.show()
../_images/chapter7_7-2_bagging_pasting_12_0.svg
[11]:
bag_clf = BaggingClassifier(DecisionTreeClassifier(), n_estimators=500, max_samples=100, bootstrap=True, n_jobs=-1)
bag_clf.fit(X_train, y_train)

past_clf = BaggingClassifier(DecisionTreeClassifier(), n_estimators=500, max_samples=100, bootstrap=False, n_jobs=-1)
past_clf.fit(X_train, y_train)
[11]:
BaggingClassifier(base_estimator=DecisionTreeClassifier(), bootstrap=False,
                  max_samples=100, n_estimators=500, n_jobs=-1)
[12]:
from matplotlib import font_manager

fontP = font_manager.FontProperties(fname="./../fonts/Arial Unicode.ttf")
# fontP.set_family('monospace')
fontP.set_size(14)

fig, axes = plt.subplots(ncols=2, figsize=(10,4), sharey=True)
plt.sca(axes[0])
plot_decision_boundary(bag_clf, X, y)
plt.title("Decision Tree with Bagging", fontproperties=fontP)
plt.sca(axes[1])
plot_decision_boundary(past_clf, X, y)
plt.title("Decision Tree with Pasting", fontproperties=fontP)
plt.ylabel("")
plt.show()
../_images/chapter7_7-2_bagging_pasting_14_0.svg

Bootstrapping introduces a bit more diversity into the subsets that each predictor is trained on, so bagging ends up with a slightly higher bias than pasting; but the extra diversity also means the predictors are less correlated, so the ensemble's variance is reduced. Overall, bagging often results in better models, which explains why it is generally preferred. However, if you have spare time and CPU power, you can use cross-validation to evaluate both bagging and pasting and pick the one that works best.
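Such a comparison could be sketched with cross_val_score; the hyperparameters below are illustrative, not tuned:

```python
from sklearn.datasets import make_moons
from sklearn.ensemble import BaggingClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = make_moons(n_samples=500, noise=0.30, random_state=42)

# same ensemble, with and without replacement
for bootstrap, name in [(True, "bagging"), (False, "pasting")]:
    clf = BaggingClassifier(DecisionTreeClassifier(), n_estimators=100,
                            max_samples=100, bootstrap=bootstrap,
                            random_state=42, n_jobs=-1)
    scores = cross_val_score(clf, X, y, cv=5)
    print(f"{name}: {scores.mean():.3f} (+/- {scores.std():.3f})")
```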

7.2.2 Out-of-Bag Evaluation

With bagging, some training instances may be sampled several times for any given predictor, while others may not be sampled at all. By default, a BaggingClassifier samples m training instances with replacement (bootstrap=True), where m is the size of the training set. This means that, on average, only about 63% of the training instances are sampled for each predictor. The remaining 37% that are not sampled are called out-of-bag (oob) instances. Note that they are not the same 37% for every predictor.

As m grows, this ratio approaches 1 − exp(−1) ≈ 63.212%.
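This limit is easy to check empirically: draw one large bootstrap sample and count the fraction of distinct instances it contains.

```python
import numpy as np

rng = np.random.default_rng(0)
m = 100_000
sample = rng.integers(0, m, size=m)  # one bootstrap draw of size m
ratio = np.unique(sample).size / m   # fraction of instances sampled
# ratio is close to 1 - exp(-1) ~= 0.63212
```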

Since a predictor never sees the oob instances during training, it can be evaluated on them without the need for a separate validation set. You can evaluate the ensemble itself by averaging the oob evaluations of each predictor.

In Scikit-Learn, you can set oob_score=True when creating a BaggingClassifier to request an automatic oob evaluation after training.

[13]:
bag_clf = BaggingClassifier(DecisionTreeClassifier(), n_estimators=500, bootstrap=True, n_jobs=-1, oob_score=True)
bag_clf.fit(X_train, y_train)
bag_clf.oob_score_
[13]:
0.904
[14]:
from sklearn.metrics import accuracy_score
accuracy_score(y_test, bag_clf.predict(X_test))
[14]:
0.904

The oob decision function for each training instance is also available through the oob_decision_function_ attribute. In this case (since the base estimator has a predict_proba() method), the decision function returns the class probabilities for each training instance.

[15]:
bag_clf.oob_decision_function_
[15]:
array([[0.43333333, 0.56666667],
       [0.36781609, 0.63218391],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.04812834, 0.95187166],
       [0.33333333, 0.66666667],
       [0.01657459, 0.98342541],
       [1.        , 0.        ],
       [0.97714286, 0.02285714],
       [0.76963351, 0.23036649],
       [0.01136364, 0.98863636],
       [0.8097561 , 0.1902439 ],
       [0.87027027, 0.12972973],
       [0.96491228, 0.03508772],
       [0.05084746, 0.94915254],
       [0.        , 1.        ],
       [0.98404255, 0.01595745],
       [0.9273743 , 0.0726257 ],
       [1.        , 0.        ],
       [0.00571429, 0.99428571],
       [0.28021978, 0.71978022],
       [0.91304348, 0.08695652],
       [1.        , 0.        ],
       [0.98235294, 0.01764706],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.67213115, 0.32786885],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.19371728, 0.80628272],
       [1.        , 0.        ],
       [0.00578035, 0.99421965],
       [0.43352601, 0.56647399],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.25730994, 0.74269006],
       [0.40659341, 0.59340659],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.00574713, 0.99425287],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.0201005 , 0.9798995 ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.99468085, 0.00531915],
       [0.92485549, 0.07514451],
       [0.97297297, 0.02702703],
       [0.95789474, 0.04210526],
       [0.00571429, 0.99428571],
       [0.05555556, 0.94444444],
       [0.98360656, 0.01639344],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.01098901, 0.98901099],
       [0.98314607, 0.01685393],
       [0.78977273, 0.21022727],
       [0.38624339, 0.61375661],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.71957672, 0.28042328],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.82513661, 0.17486339],
       [1.        , 0.        ],
       [0.59770115, 0.40229885],
       [0.11956522, 0.88043478],
       [0.58201058, 0.41798942],
       [0.85882353, 0.14117647],
       [0.        , 1.        ],
       [0.15555556, 0.84444444],
       [0.87570621, 0.12429379],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.05940594, 0.94059406],
       [0.035     , 0.965     ],
       [0.27225131, 0.72774869],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.80239521, 0.19760479],
       [0.01507538, 0.98492462],
       [0.        , 1.        ],
       [0.01069519, 0.98930481],
       [0.18905473, 0.81094527],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.96774194, 0.03225806],
       [0.81621622, 0.18378378],
       [0.00534759, 0.99465241],
       [1.        , 0.        ],
       [0.19889503, 0.80110497],
       [0.59509202, 0.40490798],
       [0.        , 1.        ],
       [0.06703911, 0.93296089],
       [0.48863636, 0.51136364],
       [1.        , 0.        ],
       [0.02298851, 0.97701149],
       [1.        , 0.        ],
       [0.24175824, 0.75824176],
       [0.44886364, 0.55113636],
       [1.        , 0.        ],
       [0.01123596, 0.98876404],
       [0.98076923, 0.01923077],
       [0.25      , 0.75      ],
       [0.93373494, 0.06626506],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.79166667, 0.20833333],
       [1.        , 0.        ],
       [0.03      , 0.97      ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.97849462, 0.02150538],
       [1.        , 0.        ],
       [0.01197605, 0.98802395],
       [0.96039604, 0.03960396],
       [0.99456522, 0.00543478],
       [0.02840909, 0.97159091],
       [0.22340426, 0.77659574],
       [0.92857143, 0.07142857],
       [0.26900585, 0.73099415],
       [0.98901099, 0.01098901],
       [0.        , 1.        ],
       [0.00543478, 0.99456522],
       [0.70175439, 0.29824561],
       [0.37628866, 0.62371134],
       [0.40223464, 0.59776536],
       [0.85964912, 0.14035088],
       [0.93121693, 0.06878307],
       [0.05      , 0.95      ],
       [0.8241206 , 0.1758794 ],
       [0.0052356 , 0.9947644 ],
       [0.        , 1.        ],
       [0.02105263, 0.97894737],
       [0.9673913 , 0.0326087 ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.97      , 0.03      ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.38219895, 0.61780105],
       [0.30319149, 0.69680851],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.36170213, 0.63829787],
       [1.        , 0.        ],
       [0.99473684, 0.00526316],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.01117318, 0.98882682],
       [0.        , 1.        ],
       [0.97905759, 0.02094241],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.6972973 , 0.3027027 ],
       [0.93333333, 0.06666667],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.98837209, 0.01162791],
       [1.        , 0.        ],
       [0.00492611, 0.99507389],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.09039548, 0.90960452],
       [1.        , 0.        ],
       [0.04545455, 0.95454545],
       [0.00571429, 0.99428571],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.02673797, 0.97326203],
       [1.        , 0.        ],
       [0.95027624, 0.04972376],
       [0.78804348, 0.21195652],
       [0.6576087 , 0.3423913 ],
       [0.00518135, 0.99481865],
       [0.11976048, 0.88023952],
       [1.        , 0.        ],
       [0.95054945, 0.04945055],
       [0.97282609, 0.02717391],
       [1.        , 0.        ],
       [0.00518135, 0.99481865],
       [0.        , 1.        ],
       [0.44067797, 0.55932203],
       [0.85714286, 0.14285714],
       [0.        , 1.        ],
       [0.00515464, 0.99484536],
       [0.99462366, 0.00537634],
       [0.01829268, 0.98170732],
       [0.        , 1.        ],
       [0.9375    , 0.0625    ],
       [0.        , 1.        ],
       [0.23350254, 0.76649746],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.99481865, 0.00518135],
       [0.78609626, 0.21390374],
       [1.        , 0.        ],
       [0.00534759, 0.99465241],
       [0.07608696, 0.92391304],
       [1.        , 0.        ],
       [0.02139037, 0.97860963],
       [0.        , 1.        ],
       [0.05487805, 0.94512195],
       [0.98843931, 0.01156069],
       [0.79274611, 0.20725389],
       [0.        , 1.        ],
       [0.87939698, 0.12060302],
       [1.        , 0.        ],
       [0.21022727, 0.78977273],
       [0.2173913 , 0.7826087 ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.25301205, 0.74698795],
       [0.93264249, 0.06735751],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.99473684, 0.00526316],
       [0.        , 1.        ],
       [0.5480226 , 0.4519774 ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.11797753, 0.88202247],
       [0.08988764, 0.91011236],
       [0.98974359, 0.01025641],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.44508671, 0.55491329],
       [0.08241758, 0.91758242],
       [0.5923913 , 0.4076087 ],
       [0.65957447, 0.34042553],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.63783784, 0.36216216],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.21875   , 0.78125   ],
       [0.83333333, 0.16666667],
       [0.06532663, 0.93467337],
       [1.        , 0.        ],
       [0.83084577, 0.16915423],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.07734807, 0.92265193],
       [0.01648352, 0.98351648],
       [0.        , 1.        ],
       [0.99473684, 0.00526316],
       [0.9017341 , 0.0982659 ],
       [0.195     , 0.805     ],
       [0.92391304, 0.07608696],
       [0.00520833, 0.99479167],
       [0.66834171, 0.33165829],
       [0.08241758, 0.91758242],
       [0.97714286, 0.02285714],
       [0.74444444, 0.25555556],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.93513514, 0.06486486],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.29378531, 0.70621469],
       [0.98404255, 0.01595745],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.01081081, 0.98918919],
       [0.83425414, 0.16574586],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.75595238, 0.24404762],
       [0.95408163, 0.04591837],
       [1.        , 0.        ],
       [0.66666667, 0.33333333],
       [0.48538012, 0.51461988],
       [0.        , 1.        ],
       [0.91712707, 0.08287293],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.92972973, 0.07027027],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.71764706, 0.28235294],
       [0.15168539, 0.84831461],
       [0.49714286, 0.50285714],
       [0.17977528, 0.82022472],
       [0.        , 1.        ],
       [0.87      , 0.13      ],
       [0.80597015, 0.19402985],
       [0.01595745, 0.98404255],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.01980198, 0.98019802],
       [0.94413408, 0.05586592],
       [0.95135135, 0.04864865],
       [1.        , 0.        ],
       [0.53157895, 0.46842105],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.96891192, 0.03108808],
       [0.03278689, 0.96721311],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.93888889, 0.06111111],
       [0.00518135, 0.99481865],
       [0.09444444, 0.90555556],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.02645503, 0.97354497],
       [1.        , 0.        ],
       [0.08426966, 0.91573034],
       [0.        , 1.        ],
       [0.01104972, 0.98895028],
       [0.        , 1.        ],
       [0.35602094, 0.64397906],
       [0.0621118 , 0.9378882 ],
       [0.21875   , 0.78125   ],
       [1.        , 0.        ],
       [0.98342541, 0.01657459],
       [0.19565217, 0.80434783],
       [0.97849462, 0.02150538],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.95833333, 0.04166667],
       [0.32984293, 0.67015707],
       [0.98113208, 0.01886792],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.04117647, 0.95882353],
       [0.99484536, 0.00515464],
       [1.        , 0.        ],
       [0.05617978, 0.94382022],
       [0.68421053, 0.31578947]])