5-5 Exercise 9

Train an SVM classifier on the MNIST dataset. Since SVM classifiers are binary classifiers, you will need to use one-versus-the-rest (OvR) to classify all 10 digits. You may want to tune the hyperparameters on a small validation set to speed things up. What accuracy can you reach?

We could split the data ourselves with train_test_split() (a sketch follows after the data is loaded below), but by convention the first 60,000 images serve as the training set and the last 10,000 as the test set. This standard split makes it easy to compare performance with other models.

[1]:
from sklearn.datasets import fetch_openml
import matplotlib
import matplotlib.pyplot as plt
[2]:
mnist = fetch_openml('mnist_784', version=1, cache=True)
mnist.keys()
[2]:
dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])
[3]:
mnist.DESCR
[3]:
"**Author**: Yann LeCun, Corinna Cortes, Christopher J.C. Burges  \n**Source**: [MNIST Website](http://yann.lecun.com/exdb/mnist/) - Date unknown  \n**Please cite**:  \n\nThe MNIST database of handwritten digits with 784 features, raw data available at: http://yann.lecun.com/exdb/mnist/. It can be split in a training set of the first 60,000 examples, and a test set of 10,000 examples  \n\nIt is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting. The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.  \n\nWith some classification methods (particularly template-based methods, such as SVM and K-nearest neighbors), the error rate improves when the digits are centered by bounding box rather than center of mass. If you do this kind of pre-processing, you should report it in your publications. The MNIST database was constructed from NIST's NIST originally designated SD-3 as their training set and SD-1 as their test set. However, SD-3 is much cleaner and easier to recognize than SD-1. The reason for this can be found on the fact that SD-3 was collected among Census Bureau employees, while SD-1 was collected among high-school students. Drawing sensible conclusions from learning experiments requires that the result be independent of the choice of training set and test among the complete set of samples. Therefore it was necessary to build a new database by mixing NIST's datasets.  \n\nThe MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers. We made sure that the sets of writers of the training set and test set were disjoint. SD-1 contains 58,527 digit images written by 500 different writers. In contrast to SD-3, where blocks of data from each writer appeared in sequence, the data in SD-1 is scrambled. Writer identities for SD-1 is available and we used this information to unscramble the writers. We then split SD-1 in two: characters written by the first 250 writers went into our new training set. The remaining 250 writers were placed in our test set. Thus we had two sets with nearly 30,000 examples each. The new training set was completed with enough examples from SD-3, starting at pattern # 0, to make a full set of 60,000 training patterns. Similarly, the new test set was completed with SD-3 examples starting at pattern # 35,000 to make a full set with 60,000 test patterns. Only a subset of 10,000 test images (5,000 from SD-1 and 5,000 from SD-3) is available on this site. The full 60,000 sample training set is available.\n\nDownloaded from openml.org."
[4]:
import numpy as np
X = mnist['data']
y = mnist['target'].astype(np.uint8)  # labels arrive as strings; cast to integers

# Standard MNIST split: first 60,000 images for training, last 10,000 for testing
X_train = X[:60000]
y_train = y[:60000]
X_test = X[60000:]
y_test = y[60000:]
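
As mentioned above, here is a minimal sketch of the train_test_split() alternative, which also holds out a small validation set for faster hyperparameter tuning, as the exercise hint suggests. The variable names here are illustrative.

[ ]:
from sklearn.model_selection import train_test_split

# Illustrative alternative to the standard split: a random 60k/10k split,
# plus a small validation set carved out of the training data
X_train_full, X_test_alt, y_train_full, y_test_alt = train_test_split(
    X, y, test_size=10000, random_state=42)
X_train_small, X_val, y_train_small, y_val = train_test_split(
    X_train_full, y_train_full, test_size=5000, random_state=42)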
[5]:
y[:10]
[5]:
array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4], dtype=uint8)

Many algorithms are sensitive to the order of the training examples, so the training set usually needs to be shuffled first. This dataset has already been shuffled, however, so we don't need to shuffle it again.
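
If shuffling were needed, a short sketch with np.random.permutation would do. This is illustrative only (MNIST is already shuffled) and assumes the data are NumPy arrays.

[ ]:
# Illustrative only -- MNIST is already shuffled, so this step is skipped
shuffle_idx = np.random.permutation(len(X_train))
X_train, y_train = X_train[shuffle_idx], y_train[shuffle_idx]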

First, let's try the simplest SVM classifier, LinearSVC. For multiclass problems it applies the OvA (also called OvR) strategy automatically, so there is nothing extra we need to do.

[6]:
from sklearn.svm import LinearSVC
lin_clf = LinearSVC(random_state=42)
lin_clf.fit(X_train, y_train)
[6]:
LinearSVC(random_state=42)
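
Incidentally, if you would rather make the OvR strategy explicit, sklearn's OneVsRestClassifier can wrap the binary classifier. This sketch is equivalent in spirit to what LinearSVC does internally, so it is not needed here.

[ ]:
from sklearn.multiclass import OneVsRestClassifier

# Explicit one-vs-rest: trains one binary LinearSVC per digit class
ovr_clf = OneVsRestClassifier(LinearSVC(random_state=42))
# ovr_clf.fit(X_train, y_train)  # would train 10 binary classifiers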

Let's measure the model's accuracy on the training set. Since we haven't selected and trained the final model yet, we should not touch the test set for now.

[7]:
from sklearn.metrics import accuracy_score
y_pred = lin_clf.predict(X_train)
accuracy_score(y_train, y_pred)
[7]:
0.8348666666666666

For MNIST, 83.49% accuracy is quite bad: a linear model is simply too weak for this dataset. But let's first try scaling the data.

[8]:
from sklearn.preprocessing import StandardScaler

# Standardize each pixel feature (zero mean, unit variance);
# float32 halves memory usage versus float64
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float32))
X_test_scaled = scaler.transform(X_test.astype(np.float32))
[10]:
lin_clf = LinearSVC(random_state=42)
lin_clf.fit(X_train_scaled, y_train)

[10]:
LinearSVC(random_state=42)
[11]:
y_pred = lin_clf.predict(X_train_scaled)
accuracy_score(y_train, y_pred)
[11]:
0.9217333333333333

Simply scaling the inputs raised the accuracy to 92.17%, but for a dataset like MNIST that is still not good enough.

To do better with an SVM we need a kernel. Let's try an SVC with the RBF kernel.

[12]:
from sklearn.svm import SVC

# RBF-kernel SVC; train on the first 10,000 samples to keep runtime manageable
svm_clf = SVC(gamma='scale')
svm_clf.fit(X_train_scaled[:10000], y_train[:10000])

[12]:
SVC()
[13]:
y_pred = svm_clf.predict(X_train_scaled)
accuracy_score(y_train, y_pred)
[13]:
0.9455333333333333

We trained the SVC on only one sixth of the data LinearSVC used, yet it already performs better, so fine-tuning the hyperparameters looks promising. Let's try a randomized search.

[14]:
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import reciprocal, uniform

# Sample gamma log-uniformly over [0.001, 0.1] and C uniformly over [1, 11]
param_distributions = {"gamma": reciprocal(0.001, 0.1), "C": uniform(1, 10)}
rnd_search_cv = RandomizedSearchCV(svm_clf, param_distributions=param_distributions, n_iter=10, verbose=2, cv=3)
rnd_search_cv.fit(X_train_scaled[:10000], y_train[:10000])
Fitting 3 folds for each of 10 candidates, totalling 30 fits
[CV] C=10.857913608904475, gamma=0.05400428325793935 .................
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[CV] .. C=10.857913608904475, gamma=0.05400428325793935, total= 1.6min
[CV] C=10.857913608904475, gamma=0.05400428325793935 .................
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:  1.6min remaining:    0.0s
[CV] .. C=10.857913608904475, gamma=0.05400428325793935, total= 1.8min
[CV] C=10.857913608904475, gamma=0.05400428325793935 .................
[CV] .. C=10.857913608904475, gamma=0.05400428325793935, total= 2.0min
[CV] C=7.779125198526879, gamma=0.0016773371141783062 ................
[CV] . C=7.779125198526879, gamma=0.0016773371141783062, total=  30.0s
[CV] C=7.779125198526879, gamma=0.0016773371141783062 ................
[CV] . C=7.779125198526879, gamma=0.0016773371141783062, total=  29.1s
[CV] C=7.779125198526879, gamma=0.0016773371141783062 ................
[CV] . C=7.779125198526879, gamma=0.0016773371141783062, total=  30.7s
[CV] C=6.578025432118354, gamma=0.010544769404567864 .................
[CV] .. C=6.578025432118354, gamma=0.010544769404567864, total= 1.6min
[CV] C=6.578025432118354, gamma=0.010544769404567864 .................
[CV] .. C=6.578025432118354, gamma=0.010544769404567864, total= 1.6min
[CV] C=6.578025432118354, gamma=0.010544769404567864 .................
[CV] .. C=6.578025432118354, gamma=0.010544769404567864, total= 1.5min
[CV] C=10.744611313642467, gamma=0.0011243399627118532 ...............
[CV]  C=10.744611313642467, gamma=0.0011243399627118532, total=  27.2s
[CV] C=10.744611313642467, gamma=0.0011243399627118532 ...............
[CV]  C=10.744611313642467, gamma=0.0011243399627118532, total=  26.1s
[CV] C=10.744611313642467, gamma=0.0011243399627118532 ...............
[CV]  C=10.744611313642467, gamma=0.0011243399627118532, total=  26.5s
[CV] C=1.1193617536785896, gamma=0.005346759068249622 ................
[CV] . C=1.1193617536785896, gamma=0.005346759068249622, total= 1.5min
[CV] C=1.1193617536785896, gamma=0.005346759068249622 ................
[CV] . C=1.1193617536785896, gamma=0.005346759068249622, total= 1.0min
[CV] C=1.1193617536785896, gamma=0.005346759068249622 ................
[CV] . C=1.1193617536785896, gamma=0.005346759068249622, total=  57.3s
[CV] C=5.981586552596838, gamma=0.016370682035152535 .................
[CV] .. C=5.981586552596838, gamma=0.016370682035152535, total= 1.5min
[CV] C=5.981586552596838, gamma=0.016370682035152535 .................
[CV] .. C=5.981586552596838, gamma=0.016370682035152535, total= 1.6min
[CV] C=5.981586552596838, gamma=0.016370682035152535 .................
[CV] .. C=5.981586552596838, gamma=0.016370682035152535, total= 1.5min
[CV] C=8.64983329530747, gamma=0.0015789278410968156 .................
[CV] .. C=8.64983329530747, gamma=0.0015789278410968156, total=  27.2s
[CV] C=8.64983329530747, gamma=0.0015789278410968156 .................
[CV] .. C=8.64983329530747, gamma=0.0015789278410968156, total=  27.2s
[CV] C=8.64983329530747, gamma=0.0015789278410968156 .................
[CV] .. C=8.64983329530747, gamma=0.0015789278410968156, total=  26.8s
[CV] C=1.6078300040855256, gamma=0.052521820356754796 ................
[CV] . C=1.6078300040855256, gamma=0.052521820356754796, total= 1.6min
[CV] C=1.6078300040855256, gamma=0.052521820356754796 ................
[CV] . C=1.6078300040855256, gamma=0.052521820356754796, total= 1.6min
[CV] C=1.6078300040855256, gamma=0.052521820356754796 ................
[CV] . C=1.6078300040855256, gamma=0.052521820356754796, total= 1.6min
[CV] C=10.008119385975288, gamma=0.002261735005964153 ................
[CV] . C=10.008119385975288, gamma=0.002261735005964153, total=  32.5s
[CV] C=10.008119385975288, gamma=0.002261735005964153 ................
[CV] . C=10.008119385975288, gamma=0.002261735005964153, total=  32.4s
[CV] C=10.008119385975288, gamma=0.002261735005964153 ................
[CV] . C=10.008119385975288, gamma=0.002261735005964153, total=  31.8s
[CV] C=9.884730797744202, gamma=0.017833926900964567 .................
[CV] .. C=9.884730797744202, gamma=0.017833926900964567, total= 1.5min
[CV] C=9.884730797744202, gamma=0.017833926900964567 .................
[CV] .. C=9.884730797744202, gamma=0.017833926900964567, total= 1.6min
[CV] C=9.884730797744202, gamma=0.017833926900964567 .................
[CV] .. C=9.884730797744202, gamma=0.017833926900964567, total= 1.6min
[Parallel(n_jobs=1)]: Done  30 out of  30 | elapsed: 33.5min finished
[14]:
RandomizedSearchCV(cv=3, estimator=SVC(),
                   param_distributions={'C': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7fdb6809e710>,
                                        'gamma': <scipy.stats._distn_infrastructure.rv_frozen object at 0x7fdb6809e050>},
                   verbose=2)
[15]:
rnd_search_cv.best_estimator_
[15]:
SVC(C=10.744611313642467, gamma=0.0011243399627118532)
[16]:
rnd_search_cv.best_score_
[16]:
0.9389998087991162

That score looks rather low, but remember we trained on only 10,000 images. Let's now retrain the best estimator on the whole training set.

[17]:
rnd_search_cv.best_estimator_.fit(X_train_scaled, y_train)
[17]:
SVC(C=10.744611313642467, gamma=0.0011243399627118532)
[18]:
y_pred = rnd_search_cv.best_estimator_.predict(X_train_scaled)
accuracy_score(y_train, y_pred)
[18]:
0.9990166666666667

Now this looks good, so we can finally evaluate the model on the test set.

[19]:
y_pred = rnd_search_cv.best_estimator_.predict(X_test_scaled)
accuracy_score(y_test, y_pred)
[19]:
0.9733

Not bad, but the training accuracy is clearly higher than the test accuracy, so the model is overfitting somewhat. We might be tempted to lower the hyperparameters (gamma and C) to increase regularization, but that would risk overfitting the test set instead. Reports on OpenML suggest that C=5 and gamma=0.005 yield even better performance (over 98% accuracy). By running the randomized search for longer on the whole training set, you may well find these values too.
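
As a closing sketch (not executed here), training with those reported hyperparameters would look as follows; C=5 and gamma=0.005 are values reported by others, not verified in this notebook.

[ ]:
# Hyperparameters reported elsewhere (not verified in this notebook)
final_clf = SVC(C=5, gamma=0.005)
final_clf.fit(X_train_scaled, y_train)
y_pred = final_clf.predict(X_test_scaled)
accuracy_score(y_test, y_pred)  # reportedly above 0.98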