{ "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4-final" }, "orig_nbformat": 2, "kernelspec": { "name": "Python 3.7.4 64-bit ('venv')", "display_name": "Python 3.7.4 64-bit ('venv')", "metadata": { "interpreter": { "hash": "e284c72d79b42194b3fe2a0767ff9cca6d233ae03063bab113c99e4bc6bd25a8" } } } }, "nbformat": 4, "nbformat_minor": 2, "cells": [ { "source": [ "# Exercise 3-3\n", "Tackle the Titanic dataset from Kaggle." ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'./datasets/titanic'" ] }, "metadata": {}, "execution_count": 1 } ], "source": [ "import os\n", "TITANIC_PATH = os.path.join('./datasets', 'titanic')\n", "TITANIC_PATH" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "def load_titanic_data(filename, titanic_path=TITANIC_PATH):\n", "    filepath = os.path.join(titanic_path, filename)\n", "    return pd.read_csv(filepath)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [] }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "titanic.zip: Skipping, found more recently modified local copy (use --force to force download)\n" ] } ], "source": [ "# pip install kaggle\n", "# Download the data with the Kaggle API (requires an API token in ~/.kaggle/kaggle.json)\n", "!kaggle competitions download -c titanic -p datasets\n", "# Extract the archive so that train.csv and test.csv end up in TITANIC_PATH\n", "!unzip -o datasets/titanic.zip -d datasets/titanic" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "train_data = load_titanic_data('train.csv')\n", "test_data = load_titanic_data('test.csv')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " PassengerId Survived Pclass \\\n", "0 1 0 3 \n", "1 2 1 1 \n", "2 3 1 3 \n", "3 4 1 1 
\n", "4 5 0 3 \n", "\n", " Name Sex Age SibSp \\\n", "0 Braund, Mr. Owen Harris male 22.0 1 \n", "1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n", "2 Heikkinen, Miss. Laina female 26.0 0 \n", "3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n", "4 Allen, Mr. William Henry male 35.0 0 \n", "\n", " Parch Ticket Fare Cabin Embarked \n", "0 0 A/5 21171 7.2500 NaN S \n", "1 0 PC 17599 71.2833 C85 C \n", "2 0 STON/O2. 3101282 7.9250 NaN S \n", "3 0 113803 53.1000 C123 S \n", "4 0 373450 8.0500 NaN S " ], "text/html": "
" }, "metadata": {}, "execution_count": 5 } ], "source": [ "train_data.head()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " PassengerId Pclass Name Sex \\\n", "0 892 3 Kelly, Mr. James male \n", "1 893 3 Wilkes, Mrs. James (Ellen Needs) female \n", "2 894 2 Myles, Mr. Thomas Francis male \n", "3 895 3 Wirz, Mr. Albert male \n", "4 896 3 Hirvonen, Mrs. Alexander (Helga E Lindqvist) female \n", "\n", " Age SibSp Parch Ticket Fare Cabin Embarked \n", "0 34.5 0 0 330911 7.8292 NaN Q \n", "1 47.0 1 0 363272 7.0000 NaN S \n", "2 62.0 0 0 240276 9.6875 NaN Q \n", "3 27.0 0 0 315154 8.6625 NaN S \n", "4 22.0 1 1 3101298 12.2875 NaN S " ], "text/html": "
" }, "metadata": {}, "execution_count": 6 } ], "source": [ "test_data.head()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "tags": [] }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "<class 'pandas.core.frame.DataFrame'>\nRangeIndex: 891 entries, 0 to 890\nData columns (total 12 columns):\n # Column Non-Null Count Dtype \n--- ------ -------------- ----- \n 0 PassengerId 891 non-null int64 \n 1 Survived 891 non-null int64 \n 2 Pclass 891 non-null int64 \n 3 Name 891 non-null object \n 4 Sex 891 non-null object \n 5 Age 714 non-null float64\n 6 SibSp 891 non-null int64 \n 7 Parch 891 non-null int64 \n 8 Ticket 891 non-null object \n 9 Fare 891 non-null float64\n 10 Cabin 204 non-null object \n 11 Embarked 889 non-null object \ndtypes: float64(2), int64(5), object(5)\nmemory usage: 83.7+ KB\n" ] } ], "source": [ "train_data.info()" ] }, { "source": [ "From `info()` we can see that **Age**, **Cabin** and **Embarked** have missing values. **Cabin** is missing for 687 of the 891 passengers (about 77%), so we have no choice but to ignore it for now. The missing **Age** values can be replaced with the median.\n", "\n", "**Name** and **Ticket** may contain some useful information, but converting them into numbers a model can use is tricky, so we ignore these attributes for now as well." ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " PassengerId Survived Pclass Age SibSp \\\n", "count 891.000000 891.000000 891.000000 714.000000 891.000000 \n", "mean 446.000000 0.383838 2.308642 29.699118 0.523008 \n", "std 257.353842 0.486592 0.836071 14.526497 1.102743 \n", "min 1.000000 0.000000 1.000000 0.420000 0.000000 \n", "25% 223.500000 0.000000 2.000000 20.125000 0.000000 \n", "50% 446.000000 0.000000 3.000000 28.000000 0.000000 \n", "75% 668.500000 1.000000 3.000000 38.000000 1.000000 \n", "max 891.000000 1.000000 3.000000 80.000000 8.000000 \n", "\n", " Parch Fare \n", "count 891.000000 891.000000 \n", "mean 0.381594 32.204208 \n", "std 0.806057 49.693429 \n", "min 0.000000 0.000000 \n", "25% 0.000000 7.910400 \n", "50% 0.000000 14.454200 \n", "75% 0.000000 31.000000 \n", "max 6.000000 512.329200 " ], 
"text/html": "
" }, "metadata": {}, "execution_count": 8 } ], "source": [ "train_data.describe()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0 549\n", "1 342\n", "Name: Survived, dtype: int64" ] }, "metadata": {}, "execution_count": 9 } ], "source": [ "train_data['Survived'].value_counts()" ] }, { "source": [ "- Only about 38.4% of the passengers (342 of 891) survived." ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "3 491\n", "1 216\n", "2 184\n", "Name: Pclass, dtype: int64" ] }, "metadata": {}, "execution_count": 10 } ], "source": [ "train_data['Pclass'].value_counts()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "male 577\n", "female 314\n", "Name: Sex, dtype: int64" ] }, "metadata": {}, "execution_count": 11 } ], "source": [ "train_data['Sex'].value_counts()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "S 644\n", "C 168\n", "Q 77\n", "Name: Embarked, dtype: int64" ] }, "metadata": {}, "execution_count": 12 } ], "source": [ "train_data['Embarked'].value_counts()" ] }, { "source": [ "Now build the preprocessing pipelines, starting with a transformer that selects the given DataFrame columns." ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from sklearn.base import BaseEstimator, TransformerMixin\n", "\n", "class DataFrameSelector(BaseEstimator, TransformerMixin):\n", "    \"\"\"Select a subset of DataFrame columns (e.g. the numerical or categorical attributes).\"\"\"\n", "    def __init__(self, attribute_names):\n", "        self.attribute_names = attribute_names\n", "\n", "    def fit(self, X, y=None):\n", "        return self  # nothing to learn\n", "\n", "    def transform(self, X):\n", "        return X[self.attribute_names]" ] }, { "source": [ "Next, a pipeline for the numerical attributes: select them, then fill missing values with the median." ], "cell_type": "markdown", "metadata": {} }, { "source": [ "from sklearn.pipeline import 
Pipeline\n", "try:\n", "    from sklearn.impute import SimpleImputer  # scikit-learn 0.20+\n", "except ImportError:\n", "    from sklearn.preprocessing import Imputer as SimpleImputer\n", "\n", "num_pipeline = Pipeline([\n", "    (\"select_numeric\", DataFrameSelector([\"Age\", \"SibSp\", \"Parch\", \"Fare\"])),\n", "    (\"imputer\", SimpleImputer(strategy='median')),  # fill missing Age values with the median\n", "])" ], "cell_type": "code", "metadata": {}, "execution_count": 14, "outputs": [] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([[22. , 1. , 0. , 7.25 ],\n", " [38. , 1. , 0. , 71.2833],\n", " [26. , 0. , 0. , 7.925 ],\n", " ...,\n", " [28. , 1. , 2. , 23.45 ],\n", " [26. , 0. , 0. , 30. ],\n", " [32. , 0. , 0. , 7.75 ]])" ] }, "metadata": {}, "execution_count": 15 } ], "source": [ "num_pipeline.fit_transform(train_data)" ] }, { "source": [ "In Scikit-Learn versions before 0.20, `OneHotEncoder` only handled integer categories, so converting string categories into one-hot vectors required `LabelBinarizer` or `CategoricalEncoder`; from 0.20 onward, `OneHotEncoder` can be used directly." ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "PassengerId 891\n", "Survived 0\n", "Pclass 3\n", "Name Allen, Miss. Elisabeth Walton\n", "Sex male\n", "Age 24\n", "SibSp 0\n", "Parch 0\n", "Ticket CA. 
2343\n", "Fare 8.05\n", "Cabin G6\n", "Embarked S\n", "dtype: object" ] }, "metadata": {}, "execution_count": 16 } ], "source": [ "most_ = pd.Series([train_data[c].value_counts().index[0] for c in train_data], index=train_data.columns)\n", "most_" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "PassengerId\nSurvived\nPclass\nName\nSex\nAge\nSibSp\nParch\nTicket\nFare\nCabin\nEmbarked\n" ] } ], "source": [ "for c in train_data:\n", "    print(c)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'Allen, Miss. Elisabeth Walton'" ] }, "metadata": {}, "execution_count": 18 } ], "source": [ "train_data['Name'].value_counts().index[0]" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "class MostFrequentImputer(BaseEstimator, TransformerMixin):\n", "    \"\"\"Fill missing values in each column with that column's most frequent value.\n", "\n", "    Example: if train_data['Sex'].value_counts() returns\n", "        male      577\n", "        female    314\n", "        Name: Sex, dtype: int64\n", "    then train_data['Sex'].value_counts().index[0] is 'male'.\n", "    \"\"\"\n", "    def fit(self, X, y=None):\n", "        # value_counts() sorts by frequency and excludes NaN, so index[0] is the mode\n", "        self.most_frequent_ = pd.Series([X[c].value_counts().index[0] for c in X],\n", "                                        index=X.columns)\n", "        return self\n", "\n", "    def transform(self, X, y=None):\n", "        return X.fillna(self.most_frequent_)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import OneHotEncoder\n" ] }, { "source": [ "Now build a pipeline for the categorical attributes." ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "cat_pipeline = Pipeline([\n", "    ('select_cat', DataFrameSelector(['Pclass', 'Sex', 'Embarked'])),\n", "    ('imputer', MostFrequentImputer()),\n", "    ('cat_encoder', OneHotEncoder(sparse=False)),  # use sparse_output=False in scikit-learn >= 1.2\n", "])" ] }, { "cell_type": "code", 
"execution_count": 22, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([[0., 0., 1., ..., 0., 0., 1.],\n", " [1., 0., 0., ..., 1., 0., 0.],\n", " [0., 0., 1., ..., 0., 0., 1.],\n", " ...,\n", " [0., 0., 1., ..., 0., 0., 1.],\n", " [1., 0., 0., ..., 1., 0., 0.],\n", " [0., 0., 1., ..., 0., 1., 0.]])" ] }, "metadata": {}, "execution_count": 22 } ], "source": [ "cat_pipeline.fit_transform(train_data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_data.head()" ] }, { "source": [ "**Combine the numerical and categorical features**" ], "cell_type": "markdown", "metadata": {} }, { "source": [ "`FeatureUnion` joins the two pipelines into `preprocess_pipeline`, which turns the raw data into input that a machine learning model can use." ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "from sklearn.pipeline import FeatureUnion\n", "preprocess_pipeline = FeatureUnion(transformer_list=[\n", "    ('num_pipeline', num_pipeline),\n", "    ('cat_pipeline', cat_pipeline),\n", "])" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([[22., 1., 0., ..., 0., 0., 1.],\n", " [38., 1., 0., ..., 1., 0., 0.],\n", " [26., 0., 0., ..., 0., 0., 1.],\n", " ...,\n", " [28., 1., 2., ..., 0., 0., 1.],\n", " [26., 0., 0., ..., 1., 0., 0.],\n", " [32., 0., 0., ..., 0., 1., 0.]])" ] }, "metadata": {}, "execution_count": 25 } ], "source": [ "X_train = preprocess_pipeline.fit_transform(train_data)\n", "X_train" ] }, { "source": [ "**Don't forget the training labels**" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0 0\n", "1 1\n", "2 1\n", "3 1\n", "4 0\n", " ..\n", "886 0\n", "887 1\n", "888 0\n", "889 1\n", "890 0\n", "Name: Survived, Length: 891, dtype: int64" ] }, "metadata": {}, "execution_count": 26 } ], "source": [ "y_train 
= train_data['Survived']\n", "y_train" ] }, { "source": [ "### Start with an SVC model" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "SVC(gamma='auto')" ] }, "metadata": {}, "execution_count": 27 } ], "source": [ "from sklearn.svm import SVC\n", "svm_clf = SVC(gamma='auto')\n", "svm_clf.fit(X_train, y_train)" ] }, { "source": [ "Use the trained SVC model to make predictions on the test set." ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1,\n", " 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1,\n", " 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0,\n", " 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1,\n", " 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,\n", " 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1,\n", " 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1,\n", " 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0,\n", " 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0,\n", " 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,\n", " 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0,\n", " 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1,\n", " 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0,\n", " 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1,\n", " 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1,\n", " 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1,
1, 1, 1, 1, 0, 1, 0, 0, 0])" ] }, "metadata": {}, "execution_count": 28 } ], "source": [ "# Only transform the test set: reuse the medians and categories learned from the training set\n", "X_test = preprocess_pipeline.transform(test_data)\n", "y_pred = svm_clf.predict(X_test)\n", "y_pred" ] }, { "source": [ "At this point we could write the SVC predictions to a CSV file in the format Kaggle expects and upload it to see our score. Before doing that, though, let's use cross-validation to estimate how well the model actually performs." ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([0.66666667, 0.66292135, 0.71910112, 0.74157303, 0.76404494,\n", " 0.71910112, 0.7752809 , 0.73033708, 0.74157303, 0.80898876])" ] }, "metadata": {}, "execution_count": 29 } ], "source": [ "from sklearn.model_selection import cross_val_score\n", "svm_scores = cross_val_score(svm_clf, X_train, y_train, cv=10)\n", "svm_scores" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.7329588014981274" ] }, "metadata": {}, "execution_count": 30 } ], "source": [ "svm_scores.mean()" ] }, { "source": [ "So the model's mean accuracy is only about 73.3%. That is clearly better than random guessing, but it is still not a great score. The Kaggle [leaderboard](https://www.kaggle.com/c/titanic/leaderboard) shows that the top entries reach (almost) 100% accuracy, but since the [test set](https://www.encyclopedia-titanica.org/titanic-victims/) labels can be found online, many of those perfect scores are not genuine, and we can safely ignore them." ], "cell_type": "markdown", "metadata": {} }, { "source": [ "### Try RandomForestClassifier" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([0.74444444, 0.79775281, 0.75280899, 0.80898876, 0.88764045,\n", " 0.83146067, 0.83146067, 0.7752809 , 0.85393258, 0.84269663])" ] }, "metadata": {}, "execution_count": 31 } ], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)\n", "forest_scores = cross_val_score(forest_clf, X_train, y_train, cv=10)\n", 
"forest_scores\n" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.8126466916354558" ] }, "metadata": {}, "execution_count": 32 } ], "source": [ "forest_scores.mean()" ] }, { "source": [ "`RandomForestClassifier` is clearly better (about 81.3% mean accuracy versus 73.3% for the SVC)." ], "cell_type": "markdown", "metadata": {} }, { "source": [ "### Try AdaBoostClassifier" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.812621722846442" ] }, "metadata": {}, "execution_count": 33 } ], "source": [ "from sklearn.ensemble import AdaBoostClassifier\n", "boost_clf = AdaBoostClassifier(n_estimators=100, random_state=42)\n", "boost_scores = cross_val_score(boost_clf, X_train, y_train, cv=10)\n", "boost_scores.mean()" ] }, { "source": [ "### Try KNeighborsClassifier" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.7172659176029963" ] }, "metadata": {}, "execution_count": 34 } ], "source": [ "from sklearn.neighbors import KNeighborsClassifier\n", "knn_clf = KNeighborsClassifier()\n", "knn_scores = cross_val_score(knn_clf, X_train, y_train, cv=10)\n", "knn_scores.mean()" ] }, { "source": [ "`RandomForestClassifier` performs well (`AdaBoostClassifier` scores about the same, k-nearest neighbors noticeably worse), so let's search for the best random forest hyperparameters." ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import GridSearchCV\n", "params_grid = [{\n", "    'n_estimators': [10, 20, 30, 40, 50],\n", "    'criterion': ['gini', 'entropy'],\n", "    'max_features': ['sqrt', 'log2']}]\n", "grid_search = GridSearchCV(forest_clf, params_grid, cv=5, verbose=3, n_jobs=-1)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", 
"text": [ "Fitting 5 folds for each of 20 candidates, totalling 100 fits\n", "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", "[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 3.6s\n", "[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 5.8s finished\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "GridSearchCV(cv=5, estimator=RandomForestClassifier(random_state=42), n_jobs=-1,\n", " param_grid=[{'criterion': ['gini', 'entropy'],\n", " 'max_features': ['sqrt', 'log2'],\n", " 'n_estimators': [10, 20, 30, 40, 50]}],\n", " verbose=3)" ] }, "metadata": {}, "execution_count": 36 } ], "source": [ "grid_search.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "{'criterion': 'entropy', 'max_features': 'sqrt', 'n_estimators': 50}" ] }, "metadata": {}, "execution_count": 37 } ], "source": [ "grid_search.best_params_" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.8182397003745316" ] }, "metadata": {}, "execution_count": 39 } ], "source": [ "forest_clf = RandomForestClassifier(**grid_search.best_params_)\n", "forest_scores = cross_val_score(forest_clf, X_train, y_train, cv=10)\n", "forest_scores.mean()" ] }, { "source": [ "### Try BaggingClassifier" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(0.8272034956304619, 0.7498901958778095)" ] }, "metadata": {}, "execution_count": 46 } ], "source": [ "from sklearn.ensemble import BaggingClassifier\n", "bagging_clf = BaggingClassifier(forest_clf)\n", "bagging_clf_acc_scores = cross_val_score(bagging_clf, X_train, y_train, cv=10, scoring='accuracy')\n", "bagging_clf_f1_scores = cross_val_score(bagging_clf, X_train, y_train, cv=10, scoring='f1')\n", "bagging_clf_acc_scores.mean(), bagging_clf_f1_scores.mean()" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "bagging_grid_params = [{\n", "    'base_estimator': [svm_clf, forest_clf],  # renamed 'estimator' in scikit-learn 1.2+\n", "    'n_estimators': [10, 20, 30, 40, 50],\n", "    'max_samples': [0.1, 0.3, 0.5, 0.8, 1.0],\n", "    'max_features': [0.1, 0.3, 0.5, 0.8, 1.0]\n", "}]\n", "# 2 x 5 x 5 x 5 = 250 candidates x 5 folds = 1250 fits, so this search can take a long time\n", "bagging_grid_search = GridSearchCV(bagging_clf, bagging_grid_params, cv=5, verbose=3, n_jobs=-1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bagging_grid_search.fit(X_train, y_train)" ] }, { "source": [ "Rather than looking only at the mean over the 10 cross-validation folds, let's plot each model's per-fold scores together with a box plot." ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "output_type": "display_data", "data": { "text/plain": "
", "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAfwAAAD4CAYAAAAJtFSxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZOklEQVR4nO3df7RdZX3n8feHhBSnCk2AWgfyA1pU0Lqw3IlxWfxRf6HLQrXTNpixxJZhOSPWWrum2NqBolaxP1i00lrKuKAURKq1po6OtYJtWgnk3oIgZKExJSTU1RWStFarhiTf+WPv2xyuN3CSe+65N2e/X2uddfZ+9o/z9ZrL5z7Pec5zUlVIkqTRdtRcFyBJkmafgS9JUgcY+JIkdYCBL0lSBxj4kiR1wMK5LmC2nHDCCbVixYq5LkOSpKGZmJh4pKpOnO7YyAb+ihUrGB8fn+syJEkamiRbD3bMIX1JkjrAwJckqQMMfEmSOsDAlySpA4Ya+EnOSfJAks1JLpnm+PIkn0tyT5LPJzm559gFSb7SPi4YZt2SJB3phhb4SRYAVwOvAs4Azk9yxpTTfhv4k6p6DnA58N722iXApcDzgJXApUkWD6t2SZKOdMPs4a8ENlfVlqraA9wMnDflnDOAW9vt23qOvxL4bFXtqqrdwGeBc4ZQsyR1x7Y7Yf3vNM8aOcP8HP5JwLae/e00PfZeXwReB1wFvBZ4SpLjD3LtSVNfIMlFwEUAy5YtG1jhkjTytt0J158L+/bAgkVwwTpYunKuq9IAzbeFd34Z+ECStcDfAg8D+/q9uKquAa4BGBsbq9koUJKOZEn6O/HXp/bHHqvK/8QeaYY5pP8wsLRn/+S27T9U1T9V1euq6rnAr7Vt/9LPtZKkJ1ZV0z8euoN611Obc9711Gb/YOca9kekYQb+RuC0JKckWQSsBtb1npDkhCSTNb0D+FC7/RngFUkWt5P1XtG2SZIGYenKZhgfHM4fUUML/KraC1xME9SbgFuq6r4klyc5tz3txcADSb4MPBV4T3vtLuBdNH80bAQub9skSYMyGfKG/UjKqA7NjI2NlV+eI0mHJolD9kewJBNVNTbdMVfakySpAwx8SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAwx8SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAwx8SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAwx8SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAwx8SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAwx8SVJj252PfdZIMfAlSU3IX39us339uYb+CBpq4Cc5J8kDSTYnuWSa48uS3JbkriT3JHl1274iybeS3N0+PjjMuiVp5D24Hvbtabb37Wn2NVIWDuuFkiwArgZeDmwHNiZZV1X395z2TuCWqvrDJGcAnwJWtMe+WlVnDqteSeqUFWfDgkXN9oJFzb5GyjB7+CuBzVW1par2ADcD5005p4Bj2+3jgH8aYn2S1F1LV8IF65rtC9Y1+xopwwz8k4BtPfvb27ZelwH/Lcl2mt79W3qOndIO9f9Nkmn/9ExyUZLxJOM7duwYYOmS1AGTIW/Yj6T5NmnvfOC6qjoZeDVwQ5KjgK8By6rqucAvATclOXbqxVV1TVWNVdXYiSeeONTCJUmaz4YZ+A8DS3v2T27bev08cAtAVd0OHAOcUFXfqaqdbfsE8FXg6bNesSRJI2KYgb8ROC3JKUkWAauBdVPOeQh4KUCS02kCf0eSE9tJfyQ5FTgN2DK0yiVJOsINLfCrai9wMfAZYBPNbPz7klyepP3wJ28H/nuSLwIfBtZWVQEvBO5JcjfwUeBNVbVrWLVLUieMX/fYZ42UoX0sD6CqPkUzGa+37X/3bN8PvGCa6z4GfGzWC5Skrhq/Dj751mZ78nls7VxVo1mQpgM9esbGxmp8fHyuy5CkoVqyZAm
7d++e0xoWL17Mrl0Ows6FJBNVNTbdsaH28CVJs2v37t0cVkeut4cP8JqrDruHn+SwrtPsMvAlSQfCfdMn4PTzHM4fQQa+JKkxttagH2HzbeEdSZI0Cwx8SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAwx8SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAwx8SVJj/Dq44bXNs0aOX48rSWpC/pNvbba/emvz7FfljpRU1VzXMCvGxsZqfHx8rsuQpOG67Li5rqBx2b/OdQWdlGSiqsamO2YPX5JGSH7j6xxWR663hw/wmqsOu4efhLrssC7VLDLwJUkHwn3TJ+D08xzOH0EGviSpMbbWoB9hfc3ST/ITSRbMdjGSJGl29PuxvBuBh5NckeTps1mQJEkavH4D/weAS4EXAZuS/F2SNyb53tkrTZIkDUpfgV9V/1ZVf1RVq4DnAHcA7wW+luSPk6yazSIlSdLMHPJKe1V1H3AlcA2wCPgZYH2SO5I8Z8D1qYMmtu7m6ts2M7F191yXInXLtjth/e80zxo5fc/ST3I08Frg54CX0vTy3wR8BFgM/Ga7ffrgy1RXTGzdzZprN7Bn734WLTyKGy9cxVnLF891WdLo23YnXH8u7NsDCxbBBetg6cq5rkoD1FfgJ/l94HyggBuAX6qq+3tO+VaSS4B/GnyJGkVJ+jpv7N0HPzaqq0RKc+LB9U3Y177m+cH1Bv6I6XdI/wzgYuCkqpoa9pMeAV4ysMo00qpq2sf4g7t4xjs/BcAz3vkpxh/cddBzJQ3QirObnn0WNM8rzp7rijRgrqWveWdi627GVixh/MFdDudLhyjJ4f9BvO3Opme/4uwZ9e5nVINmZMZr6Sd5D7Ctqj44pf1NNL3+X595mVJjMuQNe2nIlq50GH+E9Tuk/wbgrmnaJ4Cf7ffFkpyT5IEkm9v3/KceX5bktiR3Jbknyat7jr2jve6BJK/s9zUlSVL/s/S/H9gxTftO4Kn93KBdmvdq4OXAdmBjknVT5gO8E7ilqv4wyRnAp4AV7fZq4FnAfwb+OsnTq2pfn/VLktRp/fbwHwKmm8HxQprw7sdKYHNVbamqPcDNwHlTzing2Hb7OA7M+j8PuLmqvlNV/whsbu8nSZL60G8P/4+AK5MsAm5t215Ks9reFX3e4yRgW8/+duB5U865DPirJG8Bvhd4Wc+1G6Zce9LUF0hyEXARwLJly/osS5Kk0ddX4FfV7yQ5Afg9mtX1APYAV1XV+wdYz/nAde3rPR+4Icmz+724qq6hWQGQsbExp4hKktTqe6W9qnpHknfTfCYfYFNVfeMQXuthYGnP/sltW6+fB85pX+/2JMcAJ/R5rSRJOohDWku/qr5ZVRvbx6GEPcBG4LQkp7RvDawG1k055yGatwpIcjpwDM1kwXXA6iTfk+QU4DTAxZ4lSerToayl/xKaIfdlHBjWB6CqfuyJrq+qvUkuBj4DLAA+VFX3JbkcGK+qdcDbgT9O8jaaCXxrq1m94b4ktwD3A3uBNztDX5Kk/vW10l6StcAHgY/TfIHOJ4CnA6cAf1pVF89ijYfFlfaObK7UJR2e+fC7Mx9q6KrHW2mv3yH9XwYurqrzgUeBd1TVc4E/BQ51aF+SJA1Zv4F/KvDX7fZ3gCe32x8A1g64JkmSNGD9Bv5O4Cnt9sPA5EfljgeeNOiiJEnSYPU7aW898ArgXuAW4PeSvJxmRv1nZ6k2SZI0IP0G/sU0H5GDZnW9vcALaML/3bNQlyRJGqAnDPwkC2k+M/8XAFW1n/6X05UkSfPAE76HX1V7gd8Cjp79ciRJ0mzod9LeBuCs2SxEkiTNnn7fw/9j4LeTLAMmgG/2Hqyqfxh0YZIkaXD6Dfyb2uffneZY0SyVK0mS5ql+A/+UWa1CkiTNqr4Cv6q2znYhkiRp9vQV+Ele93jHq+rPB1OOJEmaDf0O6X/0IO2TX4fke/gamImtu//j+azli+e4Gun
Ik2ROX3/xYn9v56O+PpZXVUf1PoBFwPNoltx94WwWqG6Z2LqbNdduAGDNtRv+I/wl9aeqZvQYxD127do1xz8FTaffz+E/RlXtraqNwK8CfzDYktRlG7bsZM/e/QA8unc/G7bsnOOKJGk0HFbg9/gX4AcHUIcEwKpTj2fRwuaf5dELj2LVqcfPcUWSNBoyOYTzuCclPzK1CXga8CsAVXX24EubmbGxsRofH5/rMnQYJrbuZmzFEsYf3OV7+NKQJaGfXND8lGSiqsamO9bvpL1xmgl6U2eCbADeOIPapO8yGfKGvSQNzuEuvLMf2FFV3x5wPZIkaRa48I4kqbHtzgPPS1fObS0auL4m7SV5T5I3TdP+piTvGnxZkqSh2nYnXH9us339uQfCXyOj31n6bwDumqZ9AvjZwZUjPXbhHUlD8uB62Nu+S7v3282+Rkq/gf/9wI5p2ncCTx1cOeo6F96R5si3v86BxVOr3dco6XfS3kPA2cCWKe0vBLYPtCJ12nQL7zhbXxqcfpbdzW98Hbi8fUzPj+4defoN/D8CrkyyCLi1bXsp8F7gitkoTN3kwjvS7DpoUI9fB59864H911wFY2uHUZKGpK+FdwCSvBf4RZp19AH2AFdV1SWzU9rMuPDOkcuFd6Q5Mn4dbPoEnH6eYX+EeryFd/oO/PZG3wuc0e5uqqpvDKC+WWHgH9lc7UuSDt2MV9pL8gPAwqraDmzsaT8ZeLSq/nkglUqSpFnR7yz9PwVeNU37K4EbBleORsGSJUtIMqMHMON7LFmyZI5/EpI0f/Q7aW8MePM07euB3xpcORoFu3fvnhfD8f3MRpakrui3h78Q+J5p2o85SLskSZpH+g38O4D/MU37m+l5T1+SJM1P/Q7p/xpwa5LncOBz+D8G/AjN5/H7kuQc4CpgAXBtVb1vyvErgZe0u/8J+P6q+r722D7g3vbYQ1V1br+vK0lS1/X7bXkbkjwf+F/A69rmfwD+J3BiP/dIsgC4Gng5zep8G5Osq6r7e17nbT3nvwV4bs8tvlVVZ/bzWpIk6bH6HdKnqr5YVWuq6lk0s/O/DHwc+Eyft1gJbK6qLVW1B7gZOO9xzj8f+HC/9UmSpIPrO/CTLEjyuiT/F/hH4CeADwI/1OctTgK29exvb9ume63lwCkcePsA4Jgk40k2JPmJg1x3UXvO+I4d033XjyRJ3fSEQ/pJngFcSPM1uN8EbqLp4b+hdzh+wFYDH62qfT1ty6vq4SSn0swnuLeqvtp7UVVdA1wDzUp7s1SbJElHnMft4SdZD2wAFgM/XVWnVtU7OfAdiofiYWBpz/7Jbdt0VjNlOL+qHm6ftwCf57Hv70uSpMfxREP6zwf+BLiyqv5mhq+1ETgtySntt+6tBtZNPSnJM2n+wLi9p21xku9pt08AXgDM1uiCJEkj54kC/7/QDPv/XZK7krytXVf/kFXVXuBimkl+m4Bbquq+JJcn6f2I3Wrg5nrsUm2nA+NJvgjcBrxvFt9OkCRp5PT1bXlJjgF+Cvg54Edp/lC4hOaz9LtntcLD5LflzZ2ZftPdTXc8xKe/9DVe9eyn8frnLZuzOiTpSDPjb8urqm/TfEnODUl+iGYS39uAdye5taqm+2Id6ZDddMdD/OrHm/WV1n/lEYAZhb4kqdH3x/ImVdXmqrqEZgLeTwN7Bl6VOuvTX/ra4+5Lkg7PIQf+pKraV1WfqKrHWzxHOiSvevbTHndfknR4+l1LXxqKyeH7QbyHL0k6wMDXvPP65y0z6CVpwA57SF+SJB057OFr4OrSY+Gy4+a6jKYOSRJg4GsW5De+Pi8+/56Eumyuq5Ck+cEhfc07E1t3c/Vtm5nYOi/XdJKkI5I9fM0rE1t3s+baDezZu59FC4/ixgtXcdbyxXNdliQd8ezha17ZsGUne/buZ3/Bo3v3s2HLzrkuSZJGgoGveWXVqcezaOFRLAgcvfAoVp16/FyXJEkjwSF9zStnLV/MjReuYsOWnaw69Xi
H8yVpQAx8zTtnLV9s0EvSgDmkL0lSBxj4kiR1gIEvSVIHGPiSJHWAga9556Y7HuIN/+cObrrjobkuRZJGhrP0Na/cdMdD/OrH7wVg/VceAfCrciVpAAx8zYokA7nPmitgzWFeu3ixH+2TpEkO6WvgquqwHzdu2MryX/kkAMt/5ZPcuGHrYd9r165dc/yTkKT5wx6+5pXJ4fs1V8BvvvaHHc6XpAGxh695ZzLkDXtJGhwDX5KkDjDwJUnqAANfkqQOMPAlSeoAA1+SpA4w8CVJ6gADX5KkDjDwJUnqAANfkqQOGGrgJzknyQNJNie5ZJrjVya5u318Ocm/9By7IMlX2scFw6xbkqQj3dDW0k+yALgaeDmwHdiYZF1V3T95TlW9ref8twDPbbeXAJcCY0ABE+21u4dVvyRJR7Jh9vBXApuraktV7QFuBs57nPPPBz7cbr8S+GxV7WpD/rPAObNarSRJI2SYgX8SsK1nf3vb9l2SLAdOAW49lGuTXJRkPMn4jh07BlK0JEmjYL5O2lsNfLSq9h3KRVV1TVWNVdXYiSeeOEulSZJ05Blm4D8MLO3ZP7ltm85qDgznH+q1kiRpimEG/kbgtCSnJFlEE+rrpp6U5JnAYuD2nubPAK9IsjjJYuAVbZskSerD0GbpV9XeJBfTBPUC4ENVdV+Sy4HxqpoM/9XAzVVVPdfuSvIumj8aAC6vql3Dql2SpCNdenJ1pIyNjdX4+Phcl6HDlIRR/bcpSbMlyURVjU13bL5O2pMkSQNk4EuS1AEGviRJHWDgS5LUAQa+JEkdYOBLktQBBr4kSR1g4EuS1AEGviRJHWDgS5LUAQa+JEkdYOBLktQBBr4kSR1g4EuS1AEGviRJHWDgS5LUAQa+JEkdYOBLktQBBr4kSR1g4EuS1AEGvuadia27H/MsSZo5A1/zysTW3ay5dgMAa67dYOhL0oAY+JpXNmzZyZ69+wF4dO9+NmzZOccVSdJoMPA1r6w69XgWHhUAFhwVVp16/BxXJEmjwcDX/JM89lmSNGML57oAdVP6CPOvvOfVjL3n4MeraoAVSdJoM/A1Jw4W1pOT9h7du5+jFx7FjReu4qzli4dcnSSNHgNf88pZyxdz44Wr2LBlJ6tOPd6wl6QBMfA175y1fLFBL0kD5qQ9SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAzKqq5Ul2QFsnes6dNhOAB6Z6yKkDvJ378i2vKpOnO7AyAa+jmxJxqtqbK7rkLrG373R5ZC+JEkdYOBLktQBBr7mq2vmugCpo/zdG1G+hy9JUgfYw5ckqQMMfEmSOsDA19Al+bUk9yW5J8ndSS5N8t4p55yZZFO7/WCS9VOO353kS8OsW5oNSfZN/ntO8pdJvm9A912b5AODuNeU+34+yQNtzXcn+a+Dfo32dVYkef1s3LurDHwNVZLnA68BfqSqngO8DLgN+Jkpp64GPtyz/5QkS9t7nD6MWqUh+VZVnVlVzwZ2AW+e64L6sKat+cyq+mg/FyRZeIivsQIw8AfIwNewPQ14pKq+A1BVj1TV3wK7kzyv57yf5rGBfwsH/ig4f8oxaVTcDpwEkGRlktuT3JXkC0me0bavTfLnSf5fkq8kef/kxUnemOTLSe4EXtDTviLJre2o2ueSLGvbr0vyh0k2JNmS5MVJPpRkU5Lr+i06yZIkf9Hef0OS57TtlyW5IcnfAzckOTHJx5JsbB8vaM97Uc+IwV1JngK8Dzi7bXvbTH+wAqrKh4+hPYAnA3cDXwb+AHhR2/7LwJXt9ipgvOeaB4FnAF9o9+8CzgC+NNf/e3z4mOkD+Eb7vAD4M+Ccdv9YYGG7/TLgY+32WmALcBxwDM0S4ktp/ph+CDgRWAT8PfCB9pq/BC5ot38O+It2+zrgZiDAecDXgR+m6QxOAGdOU+/ngQfa3+O7geOB3wcubY//GHB3u31Ze58ntfs3AT/abi8DNvXU94J2+8nAQuDFwCfn+v+
fUXoc6hCLNCNV9Y0kZwFnAy8BPpLkEuAjwBeSvJ3vHs4H2EkzCrAa2AT8+xDLlmbTk5LcTdOz3wR8tm0/Drg+yWlAAUf3XPO5qvpXgCT3A8tp1sD/fFXtaNs/Ajy9Pf/5wOva7RuA9/fc6y+rqpLcC/xzVd3bXn8fzbD63dPUvKaqxid3kvwo8JMAVXVrkuOTHNseXldV32q3XwackWTy0mOTPJnmj5PfTXIj8OdVtb3nHA2IQ/oauqraV1Wfr6pLgYuBn6yqbcA/Ai+i+Q/HR6a59CPA1Ticr9Hyrao6kya0w4H38N8F3FbNe/s/TtObn/Sdnu19MKPO2+S99k+57/4Z3nfSN3u2jwJW1YH3/0+qqm9U1fuAC4EnAX+f5JkDeF1NYeBrqJI8o+2xTDqTA99q+GHgSmBLVW2f5vKP0/RMPjOrRUpzoKr+HfgF4O3tBLfjgIfbw2v7uMUdwIva3vXRwE/1HPsCzcgZwBpg/dSLZ2h9e1+SvJhmns7Xpznvr4C3TO4kObN9/sGqureqrgA2As8E/g14yoDr7DQDX8P2ZJphyvuT3EPzXvxl7bE/A57FQXrwVfVvVXVFVe0ZSqXSkFXVXcA9NBNT3w+8N8ld9NHTrqqv0fwu3U4zRL6p5/BbgDe2v3NvAN462Mq5DDirvf/7gAsOct4vAGPt5L77gTe17b/YfizxHuBR4NM0P4d9Sb7opL3BcGldSZI6wB6+JEkdYOBLktQBBr4kSR1g4EuS1AEGviRJHWDgS5LUAQa+JEkd8P8BRwxIZRxqTYcAAAAASUVORK5CYII=\n" }, "metadata": { "needs_background": "light" } } ], "source": [ "plt.figure(figsize=(8, 4))\n", "plt.plot([1]*10, svm_scores, '.')\n", "plt.plot([2]*10, forest_scores, '.')\n", "plt.boxplot([svm_scores, forest_scores], labels=('SVM', 'Random Forest'))\n", "plt.ylabel('Accuracy', fontsize=14)\n", "plt.show()" ] }, { "source": [ "为了进一步改善结果,可以进行如下操作:\n", "- 对更多模型,使用cross validation和grid search调整超参数\n", "- 使用更多的特征工程,例如:\n", " - 是**SibSp**和**Parch**的和代替他们\n", " - 尝试找出与**Survived**属性很好相关的部分\n", "- 尝试把年龄属性更改为年龄段属性\n" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " Survived\n", "AgeBucket \n", "0.0 0.576923\n", "15.0 0.362745\n", "30.0 0.423256\n", "45.0 0.404494\n", "60.0 0.240000\n", "75.0 1.000000" ], "text/html": "
" }, "metadata": {}, "execution_count": 44 } ], "source": [ "train_data[\"AgeBucket\"] = train_data[\"Age\"] // 15 * 15\n", "train_data[[\"AgeBucket\", \"Survived\"]].groupby(['AgeBucket']).mean()" ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " Survived\n", "RelativesOnboard \n", "0 0.303538\n", "1 0.552795\n", "2 0.578431\n", "3 0.724138\n", "4 0.200000\n", "5 0.136364\n", "6 0.333333\n", "7 0.000000\n", "10 0.000000" ], "text/html": "
" }, "metadata": {}, "execution_count": 77 } ], "source": [ "train_data[\"RelativesOnboard\"] = train_data[\"SibSp\"] + train_data[\"Parch\"]\n", "train_data[[\"RelativesOnboard\", \"Survived\"]].groupby(['RelativesOnboard']).mean()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ] }