{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4-final"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "Python 3.7.4 64-bit ('venv')",
"display_name": "Python 3.7.4 64-bit ('venv')",
"metadata": {
"interpreter": {
"hash": "e284c72d79b42194b3fe2a0767ff9cca6d233ae03063bab113c99e4bc6bd25a8"
}
}
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"source": [
"# 练习 3-3\n",
"处理Kaggle上的泰坦尼克数据集"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'./datasets/titanic'"
]
},
"metadata": {},
"execution_count": 1
}
],
"source": [
"import os\n",
"TITANTIC_PATH = os.path.join('./datasets', 'titanic')\n",
"TITANTIC_PATH"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"def load_titantic_data(filename, titantic_path=TITANTIC_PATH):\n",
" filepath = os.path.join(titantic_path, filename)\n",
" return pd.read_csv(filepath)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"titanic.zip: Skipping, found more recently modified local copy (use --force to force download)\n"
]
}
],
"source": [
"# pip install kaggle\n",
"# 使用kaggle提供的api下载数据\n",
"!kaggle competitions download -c titanic -p datasets"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"train_data = load_titantic_data('train.csv')\n",
"test_data = load_titantic_data('test.csv')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" PassengerId Survived Pclass \\\n",
"0 1 0 3 \n",
"1 2 1 1 \n",
"2 3 1 3 \n",
"3 4 1 1 \n",
"4 5 0 3 \n",
"\n",
" Name Sex Age SibSp \\\n",
"0 Braund, Mr. Owen Harris male 22.0 1 \n",
"1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n",
"2 Heikkinen, Miss. Laina female 26.0 0 \n",
"3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n",
"4 Allen, Mr. William Henry male 35.0 0 \n",
"\n",
" Parch Ticket Fare Cabin Embarked \n",
"0 0 A/5 21171 7.2500 NaN S \n",
"1 0 PC 17599 71.2833 C85 C \n",
"2 0 STON/O2. 3101282 7.9250 NaN S \n",
"3 0 113803 53.1000 C123 S \n",
"4 0 373450 8.0500 NaN S "
],
"text/html": "
\n\n
\n \n \n | \n PassengerId | \n Survived | \n Pclass | \n Name | \n Sex | \n Age | \n SibSp | \n Parch | \n Ticket | \n Fare | \n Cabin | \n Embarked | \n
\n \n \n \n 0 | \n 1 | \n 0 | \n 3 | \n Braund, Mr. Owen Harris | \n male | \n 22.0 | \n 1 | \n 0 | \n A/5 21171 | \n 7.2500 | \n NaN | \n S | \n
\n \n 1 | \n 2 | \n 1 | \n 1 | \n Cumings, Mrs. John Bradley (Florence Briggs Th... | \n female | \n 38.0 | \n 1 | \n 0 | \n PC 17599 | \n 71.2833 | \n C85 | \n C | \n
\n \n 2 | \n 3 | \n 1 | \n 3 | \n Heikkinen, Miss. Laina | \n female | \n 26.0 | \n 0 | \n 0 | \n STON/O2. 3101282 | \n 7.9250 | \n NaN | \n S | \n
\n \n 3 | \n 4 | \n 1 | \n 1 | \n Futrelle, Mrs. Jacques Heath (Lily May Peel) | \n female | \n 35.0 | \n 1 | \n 0 | \n 113803 | \n 53.1000 | \n C123 | \n S | \n
\n \n 4 | \n 5 | \n 0 | \n 3 | \n Allen, Mr. William Henry | \n male | \n 35.0 | \n 0 | \n 0 | \n 373450 | \n 8.0500 | \n NaN | \n S | \n
\n \n
\n
"
},
"metadata": {},
"execution_count": 5
}
],
"source": [
"train_data.head()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" PassengerId Pclass Name Sex \\\n",
"0 892 3 Kelly, Mr. James male \n",
"1 893 3 Wilkes, Mrs. James (Ellen Needs) female \n",
"2 894 2 Myles, Mr. Thomas Francis male \n",
"3 895 3 Wirz, Mr. Albert male \n",
"4 896 3 Hirvonen, Mrs. Alexander (Helga E Lindqvist) female \n",
"\n",
" Age SibSp Parch Ticket Fare Cabin Embarked \n",
"0 34.5 0 0 330911 7.8292 NaN Q \n",
"1 47.0 1 0 363272 7.0000 NaN S \n",
"2 62.0 0 0 240276 9.6875 NaN Q \n",
"3 27.0 0 0 315154 8.6625 NaN S \n",
"4 22.0 1 1 3101298 12.2875 NaN S "
],
"text/html": "\n\n
\n \n \n | \n PassengerId | \n Pclass | \n Name | \n Sex | \n Age | \n SibSp | \n Parch | \n Ticket | \n Fare | \n Cabin | \n Embarked | \n
\n \n \n \n 0 | \n 892 | \n 3 | \n Kelly, Mr. James | \n male | \n 34.5 | \n 0 | \n 0 | \n 330911 | \n 7.8292 | \n NaN | \n Q | \n
\n \n 1 | \n 893 | \n 3 | \n Wilkes, Mrs. James (Ellen Needs) | \n female | \n 47.0 | \n 1 | \n 0 | \n 363272 | \n 7.0000 | \n NaN | \n S | \n
\n \n 2 | \n 894 | \n 2 | \n Myles, Mr. Thomas Francis | \n male | \n 62.0 | \n 0 | \n 0 | \n 240276 | \n 9.6875 | \n NaN | \n Q | \n
\n \n 3 | \n 895 | \n 3 | \n Wirz, Mr. Albert | \n male | \n 27.0 | \n 0 | \n 0 | \n 315154 | \n 8.6625 | \n NaN | \n S | \n
\n \n 4 | \n 896 | \n 3 | \n Hirvonen, Mrs. Alexander (Helga E Lindqvist) | \n female | \n 22.0 | \n 1 | \n 1 | \n 3101298 | \n 12.2875 | \n NaN | \n S | \n
\n \n
\n
"
},
"metadata": {},
"execution_count": 6
}
],
"source": [
"test_data.head()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\nRangeIndex: 891 entries, 0 to 890\nData columns (total 12 columns):\n # Column Non-Null Count Dtype \n--- ------ -------------- ----- \n 0 PassengerId 891 non-null int64 \n 1 Survived 891 non-null int64 \n 2 Pclass 891 non-null int64 \n 3 Name 891 non-null object \n 4 Sex 891 non-null object \n 5 Age 714 non-null float64\n 6 SibSp 891 non-null int64 \n 7 Parch 891 non-null int64 \n 8 Ticket 891 non-null object \n 9 Fare 891 non-null float64\n 10 Cabin 204 non-null object \n 11 Embarked 889 non-null object \ndtypes: float64(2), int64(5), object(5)\nmemory usage: 83.7+ KB\n"
]
}
],
"source": [
"train_data.info()"
]
},
{
"source": [
"可以看出**Age, Cabin, Embarked**数据是不完全的,特别是的Cabin缺失了74%的数据,没办法只能忽略掉Cabin记录了。Age的缺失的数据可以使用median来代替。\n",
"\n",
"而**Name、Ticket**可能有些数字,但是转化为模型可以使用的数字有些棘手,所以也暂时忽略掉相关的记录"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" PassengerId Survived Pclass Age SibSp \\\n",
"count 891.000000 891.000000 891.000000 714.000000 891.000000 \n",
"mean 446.000000 0.383838 2.308642 29.699118 0.523008 \n",
"std 257.353842 0.486592 0.836071 14.526497 1.102743 \n",
"min 1.000000 0.000000 1.000000 0.420000 0.000000 \n",
"25% 223.500000 0.000000 2.000000 20.125000 0.000000 \n",
"50% 446.000000 0.000000 3.000000 28.000000 0.000000 \n",
"75% 668.500000 1.000000 3.000000 38.000000 1.000000 \n",
"max 891.000000 1.000000 3.000000 80.000000 8.000000 \n",
"\n",
" Parch Fare \n",
"count 891.000000 891.000000 \n",
"mean 0.381594 32.204208 \n",
"std 0.806057 49.693429 \n",
"min 0.000000 0.000000 \n",
"25% 0.000000 7.910400 \n",
"50% 0.000000 14.454200 \n",
"75% 0.000000 31.000000 \n",
"max 6.000000 512.329200 "
],
"text/html": "\n\n
\n \n \n | \n PassengerId | \n Survived | \n Pclass | \n Age | \n SibSp | \n Parch | \n Fare | \n
\n \n \n \n count | \n 891.000000 | \n 891.000000 | \n 891.000000 | \n 714.000000 | \n 891.000000 | \n 891.000000 | \n 891.000000 | \n
\n \n mean | \n 446.000000 | \n 0.383838 | \n 2.308642 | \n 29.699118 | \n 0.523008 | \n 0.381594 | \n 32.204208 | \n
\n \n std | \n 257.353842 | \n 0.486592 | \n 0.836071 | \n 14.526497 | \n 1.102743 | \n 0.806057 | \n 49.693429 | \n
\n \n min | \n 1.000000 | \n 0.000000 | \n 1.000000 | \n 0.420000 | \n 0.000000 | \n 0.000000 | \n 0.000000 | \n
\n \n 25% | \n 223.500000 | \n 0.000000 | \n 2.000000 | \n 20.125000 | \n 0.000000 | \n 0.000000 | \n 7.910400 | \n
\n \n 50% | \n 446.000000 | \n 0.000000 | \n 3.000000 | \n 28.000000 | \n 0.000000 | \n 0.000000 | \n 14.454200 | \n
\n \n 75% | \n 668.500000 | \n 1.000000 | \n 3.000000 | \n 38.000000 | \n 1.000000 | \n 0.000000 | \n 31.000000 | \n
\n \n max | \n 891.000000 | \n 1.000000 | \n 3.000000 | \n 80.000000 | \n 8.000000 | \n 6.000000 | \n 512.329200 | \n
\n \n
\n
"
},
"metadata": {},
"execution_count": 8
}
],
"source": [
"train_data.describe()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0 549\n",
"1 342\n",
"Name: Survived, dtype: int64"
]
},
"metadata": {},
"execution_count": 9
}
],
"source": [
"train_data['Survived'].value_counts()"
]
},
{
"source": [
"- 只有38.34%的人活了下来"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"3 491\n",
"1 216\n",
"2 184\n",
"Name: Pclass, dtype: int64"
]
},
"metadata": {},
"execution_count": 10
}
],
"source": [
"train_data['Pclass'].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"male 577\n",
"female 314\n",
"Name: Sex, dtype: int64"
]
},
"metadata": {},
"execution_count": 11
}
],
"source": [
"train_data['Sex'].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"S 644\n",
"C 168\n",
"Q 77\n",
"Name: Embarked, dtype: int64"
]
},
"metadata": {},
"execution_count": 12
}
],
"source": [
"train_data['Embarked'].value_counts()"
]
},
{
"source": [
"创建一个预处理Pipeline"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"\n",
"class DataFrameSelector(BaseEstimator, TransformerMixin):\n",
" def __init__(self, attribute_names):\n",
" self.attribute_names = attribute_names\n",
"\n",
" def fit(self, X, y=None):\n",
" return self\n",
"\n",
" def transform(self, X):\n",
" return X[self.attribute_names]"
]
},
{
"source": [
"创建一个pipeline选出数值属性"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"from sklearn.pipeline import Pipeline\n",
"try:\n",
" from sklearn.impute import SimpleImputer # scikit-learn 0.20+\n",
"except ImportError:\n",
" from sklearn.preprocessing import Imputer as SimpleImputer\n",
"num_pipeline = Pipeline([\n",
" (\"select_numeric\", DataFrameSelector([\"Age\", \"SibSp\", \"Parch\", \"Fare\"])),\n",
" (\"imputer\", SimpleImputer(strategy='median'))\n",
"])"
],
"cell_type": "code",
"metadata": {},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[22. , 1. , 0. , 7.25 ],\n",
" [38. , 1. , 0. , 71.2833],\n",
" [26. , 0. , 0. , 7.925 ],\n",
" ...,\n",
" [28. , 1. , 2. , 23.45 ],\n",
" [26. , 0. , 0. , 30. ],\n",
" [32. , 0. , 0. , 7.75 ]])"
]
},
"metadata": {},
"execution_count": 15
}
],
"source": [
"num_pipeline.fit_transform(train_data)"
]
},
{
"source": [
"在0.20以前版本的Scikit-Learn需要使用`LabelBinarizer`或`CategoricalEncoder`才能将分类的值转化为one-hot-vector; 0.20+以上的版本可以直接使用`OneHotEncoder`类"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"PassengerId 891\n",
"Survived 0\n",
"Pclass 3\n",
"Name Allen, Miss. Elisabeth Walton\n",
"Sex male\n",
"Age 24\n",
"SibSp 0\n",
"Parch 0\n",
"Ticket CA. 2343\n",
"Fare 8.05\n",
"Cabin G6\n",
"Embarked S\n",
"dtype: object"
]
},
"metadata": {},
"execution_count": 16
}
],
"source": [
"most_ = pd.Series([train_data[c].value_counts().index[0] for c in train_data], index=train_data.columns)\n",
"most_"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"PassengerId\nSurvived\nPclass\nName\nSex\nAge\nSibSp\nParch\nTicket\nFare\nCabin\nEmbarked\n"
]
}
],
"source": [
"for c in train_data:\n",
" print(c)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'Allen, Miss. Elisabeth Walton'"
]
},
"metadata": {},
"execution_count": 18
}
],
"source": [
"train_data['Name'].value_counts().index[0]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"class MostFrequentImputer(BaseEstimator, TransformerMixin):\n",
" \"\"\"\n",
" 计算分类中的值出现的次数最多的的类表是多少\n",
" 示例:\n",
" 如tain_data['Set'].value_counts():\n",
" male 577\n",
" female 314\n",
" Name: Sex, dtype: int64\n",
"\n",
" 则ain_data['Set'].value_counts().index[0]:male\n",
" \"\"\"\n",
" def fit(self, X, y=None):\n",
" self.most_requent_ = pd.Series([X[c].value_counts().index[0] for c in X], index=X.columns)\n",
" return self\n",
" \n",
" def transform(self, X, y=None):\n",
" return X.fillna(self.most_requent_)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import OneHotEncoder\n"
]
},
{
"source": [
"现在创建一个pipeline来处理分类属性"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"cat_pipeline = Pipeline([\n",
" ('select_cat', DataFrameSelector(['Pclass', 'Sex', 'Embarked'])),\n",
" ('imputer', MostFrequentImputer()),\n",
" ('cat_encoder', OneHotEncoder(sparse=False))\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[0., 0., 1., ..., 0., 0., 1.],\n",
" [1., 0., 0., ..., 1., 0., 0.],\n",
" [0., 0., 1., ..., 0., 0., 1.],\n",
" ...,\n",
" [0., 0., 1., ..., 0., 0., 1.],\n",
" [1., 0., 0., ..., 1., 0., 0.],\n",
" [0., 0., 1., ..., 0., 1., 0.]])"
]
},
"metadata": {},
"execution_count": 22
}
],
"source": [
"cat_pipeline.fit_transform(train_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data.head()"
]
},
{
"source": [
"**将数字特征和分类特征结合起来**"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"现在可是使用这个preprocess_pipeline将raw data转化为机器学习模型使用的数据了"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import FeatureUnion\n",
"preprocess_pipeline = FeatureUnion(transformer_list=[\n",
" ('num_pipeline', num_pipeline),\n",
" ('cat_pipeline', cat_pipeline),\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[22., 1., 0., ..., 0., 0., 1.],\n",
" [38., 1., 0., ..., 1., 0., 0.],\n",
" [26., 0., 0., ..., 0., 0., 1.],\n",
" ...,\n",
" [28., 1., 2., ..., 0., 0., 1.],\n",
" [26., 0., 0., ..., 1., 0., 0.],\n",
" [32., 0., 0., ..., 0., 1., 0.]])"
]
},
"metadata": {},
"execution_count": 25
}
],
"source": [
"X_train = preprocess_pipeline.fit_transform(train_data)\n",
"X_train"
]
},
{
"source": [
"**不要忘了训练的标签数据**"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0 0\n",
"1 1\n",
"2 1\n",
"3 1\n",
"4 0\n",
" ..\n",
"886 0\n",
"887 1\n",
"888 0\n",
"889 1\n",
"890 0\n",
"Name: Survived, Length: 891, dtype: int64"
]
},
"metadata": {},
"execution_count": 26
}
],
"source": [
"y_train = train_data['Survived']\n",
"y_train"
]
},
{
"source": [
"### 首先使用SVC模型测试一下"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"SVC(gamma='auto')"
]
},
"metadata": {},
"execution_count": 27
}
],
"source": [
"from sklearn.svm import SVC\n",
"svm_clf = SVC(gamma='auto')\n",
"svm_clf.fit(X_train, y_train)"
]
},
{
"source": [
"使用训练好的SVC模型进行预测"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1,\n",
" 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1,\n",
" 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0,\n",
" 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1,\n",
" 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,\n",
" 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1,\n",
" 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1,\n",
" 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0,\n",
" 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0,\n",
" 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,\n",
" 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0,\n",
" 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1,\n",
" 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0,\n",
" 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1,\n",
" 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1,\n",
" 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0])"
]
},
"metadata": {},
"execution_count": 28
}
],
"source": [
"X_test = preprocess_pipeline.fit_transform(test_data)\n",
"y_pred = svm_clf.predict(X_test)\n",
"y_pred"
]
},
{
"source": [
"此时我们可是使用SVC预测的结果按照Kaggle要求的格式构建号CSV文件,上传Kaggle看我们的得分,不过在此之前我们可是使用交叉验证的方法来看看我们的模型表现如何"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0.66666667, 0.66292135, 0.71910112, 0.74157303, 0.76404494,\n",
" 0.71910112, 0.7752809 , 0.73033708, 0.74157303, 0.80898876])"
]
},
"metadata": {},
"execution_count": 29
}
],
"source": [
"from sklearn.model_selection import cross_val_score\n",
"svm_scores = cross_val_score(svm_clf, X_train, y_train, cv=10)\n",
"svm_scores"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.7329588014981274"
]
},
"metadata": {},
"execution_count": 30
}
],
"source": [
"svm_scores.mean()"
]
},
{
"source": [
"也就是说模型的accuracy只有73.30%,虽然明显要比随你乱猜要好,但是仍然不是一个好的得分。从kaggle的[leaderboard](https://www.kaggle.com/c/titanic/leaderboard)可以看到前面排名几乎都达到了100%, 不过由于可以下载的到[测试集](https://www.encyclopedia-titanica.org/titanic-victims/),100%的成绩中有比较大的水分,我们无需理会这些"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"### 试试RandomForestClassifier"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0.74444444, 0.79775281, 0.75280899, 0.80898876, 0.88764045,\n",
" 0.83146067, 0.83146067, 0.7752809 , 0.85393258, 0.84269663])"
]
},
"metadata": {},
"execution_count": 31
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)\n",
"forest_scores = cross_val_score(forest_clf, X_train, y_train, cv=10)\n",
"forest_scores\n"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.8126466916354558"
]
},
"metadata": {},
"execution_count": 32
}
],
"source": [
"forest_scores.mean()"
]
},
{
"source": [
"可以看到RandomForestClassifier明显好了一下"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"### 试试AdaBoostClassifier"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.812621722846442"
]
},
"metadata": {},
"execution_count": 33
}
],
"source": [
"from sklearn.ensemble import AdaBoostClassifier\n",
"boost_clf = AdaBoostClassifier(n_estimators=100, random_state=42)\n",
"boost_scores = cross_val_score(boost_clf, X_train, y_train, cv=10)\n",
"boost_scores.mean()"
]
},
{
"source": [
"### 试试KNeighborsClassifier"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.7172659176029963"
]
},
"metadata": {},
"execution_count": 34
}
],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"knn_clf = KNeighborsClassifier()\n",
"knn_scores = cross_val_score(knn_clf, X_train, y_train, cv=10)\n",
"knn_scores.mean()"
]
},
{
"source": [
"可以看到`RandomForestClassifier`表现的还是不错的,我们来找一下随机森林的最佳参数"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"params_grid = [{\n",
" 'n_estimators':[10,20,30, 40, 50],\n",
" 'criterion':['gini','entropy'],\n",
" 'max_features':['sqrt','log2']}]\n",
"grid_search = GridSearchCV(forest_clf, params_grid, cv=5, verbose=3, n_jobs=-1)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Fitting 5 folds for each of 20 candidates, totalling 100 fits\n",
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
"[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 3.6s\n",
"[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed: 5.8s finished\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"GridSearchCV(cv=5, estimator=RandomForestClassifier(random_state=42), n_jobs=-1,\n",
" param_grid=[{'criterion': ['gini', 'entropy'],\n",
" 'max_features': ['sqrt', 'log2'],\n",
" 'n_estimators': [10, 20, 30, 40, 50]}],\n",
" verbose=3)"
]
},
"metadata": {},
"execution_count": 36
}
],
"source": [
"grid_search.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'criterion': 'entropy', 'max_features': 'sqrt', 'n_estimators': 50}"
]
},
"metadata": {},
"execution_count": 37
}
],
"source": [
"grid_search.best_params_"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.8182397003745316"
]
},
"metadata": {},
"execution_count": 39
}
],
"source": [
"forest_clf = RandomForestClassifier(**grid_search.best_params_)\n",
"forest_scores = cross_val_score(forest_clf, X_train, y_train, cv=10)\n",
"forest_scores.mean()"
]
},
{
"source": [
"### 实时BaggingClassifier"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(0.8272034956304619, 0.7498901958778095)"
]
},
"metadata": {},
"execution_count": 46
}
],
"source": [
"from sklearn.ensemble import BaggingClassifier\n",
"bagging_clf = BaggingClassifier(forest_clf)\n",
"bagging_clf_acc_sorces = cross_val_score(bagging_clf, X_train, y_train, cv=10, scoring='accuracy')\n",
"bagging_clf_f1_sorces = cross_val_score(bagging_clf, X_train, y_train, cv=10, scoring='f1')\n",
"bagging_clf_acc_sorces.mean(), bagging_clf_f1_sorces.mean()"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"bagging_grid_params=[{\n",
" 'base_estimator':[svm_clf, forest_clf],\n",
" 'n_estimators':[10,20,30, 40, 50],\n",
" 'max_samples':[0.1, 0.3, 0.5, 0.8, 1.0],\n",
" 'max_features':[0.1, 0.3, 0.5, 0.8, 1.0]\n",
"}]\n",
"bagging_grid_search = GridSearchCV(bagging_clf, bagging_grid_params, cv=5, verbose=3, n_jobs=-1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bagging_grid_search.fit(X_train, y_train)"
]
},
{
"source": [
"这里我们不关注每个模型的10折叠的平均分,而是看一下每个模型的每次折叠的箱线图"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "",
"image/svg+xml": "\n\n\n\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfwAAAD4CAYAAAAJtFSxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZOklEQVR4nO3df7RdZX3n8feHhBSnCk2AWgfyA1pU0Lqw3IlxWfxRf6HLQrXTNpixxJZhOSPWWrum2NqBolaxP1i00lrKuKAURKq1po6OtYJtWgnk3oIgZKExJSTU1RWStFarhiTf+WPv2xyuN3CSe+65N2e/X2uddfZ+9o/z9ZrL5z7Pec5zUlVIkqTRdtRcFyBJkmafgS9JUgcY+JIkdYCBL0lSBxj4kiR1wMK5LmC2nHDCCbVixYq5LkOSpKGZmJh4pKpOnO7YyAb+ihUrGB8fn+syJEkamiRbD3bMIX1JkjrAwJckqQMMfEmSOsDAlySpA4Ya+EnOSfJAks1JLpnm+PIkn0tyT5LPJzm559gFSb7SPi4YZt2SJB3phhb4SRYAVwOvAs4Azk9yxpTTfhv4k6p6DnA58N722iXApcDzgJXApUkWD6t2SZKOdMPs4a8ENlfVlqraA9wMnDflnDOAW9vt23qOvxL4bFXtqqrdwGeBc4ZQsyR1x7Y7Yf3vNM8aOcP8HP5JwLae/e00PfZeXwReB1wFvBZ4SpLjD3LtSVNfIMlFwEUAy5YtG1jhkjTytt0J158L+/bAgkVwwTpYunKuq9IAzbeFd34Z+ECStcDfAg8D+/q9uKquAa4BGBsbq9koUJKOZEn6O/HXp/bHHqvK/8QeaYY5pP8wsLRn/+S27T9U1T9V1euq6rnAr7Vt/9LPtZKkJ1ZV0z8euoN611Obc9711Gb/YOca9kekYQb+RuC0JKckWQSsBtb1npDkhCSTNb0D+FC7/RngFUkWt5P1XtG2SZIGYenKZhgfHM4fUUML/KraC1xME9SbgFuq6r4klyc5tz3txcADSb4MPBV4T3vtLuBdNH80bAQub9skSYMyGfKG/UjKqA7NjI2NlV+eI0mHJolD9kewJBNVNTbdMVfakySpAwx8SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAwx8SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAwx8SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAwx8SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAwx8SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAwx8SVJj252PfdZIMfAlSU3IX39us339uYb+CBpq4Cc5J8kDSTYnuWSa48uS3JbkriT3JHl1274iybeS3N0+PjjMuiVp5D24Hvbtabb37Wn2NVIWDuuFkiwArgZeDmwHNiZZV1X395z2TuCWqvrDJGcAnwJWtMe+WlVnDqteSeqUFWfDgkXN9oJFzb5GyjB7+CuBzVW1par2ADcD5005p4Bj2+3jgH8aYn2S1F1LV8IF65rtC9Y1+xopwwz8k4BtPfvb27ZelwH/Lcl2mt79W3qOndIO9f9Nkmn/9ExyUZLxJOM7duwYYOmS1AGTIW/Yj6T5NmnvfOC6qjoZeDVwQ5KjgK8By6rqucAvATclOXbqxVV1TVWNVdXYiSeeONTCJUmaz4YZ+A8DS3v2T27bev08cAtAVd0OHAOcUFXfqaqdbfsE8FXg6bNesSRJI2KYgb8ROC3JKUkWAauBdVPOeQh4KUCS02kCf0eSE9tJfyQ5FTgN2DK0yiVJOsINLfCrai9wMfAZYBPNbPz7klyepP3wJ28H/nuSLwIfBtZWVQEvBO5JcjfwUeBNVbVrWLVLUieMX/fYZ42UoX0sD6CqPkUzGa+37X/3bN8PvGCa6z4GfGzWC5Skrhq/Dj751mZ78nls7VxVo1mQpgM9esbGxmp8fHyuy5CkoVqyZAm7d++e0xoWL17Mrl0Ows6FJBNVNTbdsaH28CVJs2v37t0cVkeut4cP8JqrDruHn+SwrtPsMvAlSQfCfdMn4PTzHM4fQQa+JKkxttagH2HzbeEdSZI0Cwx8SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAwx8SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAwx8SVJj/Dq44bXNs0aOX48rSWpC/pNvbba/emvz7FfljpRU1VzXMCvGxsZqfHx8rsuQpOG67Li5rqBx2b/OdQWdlGSiqsamO2YPX5JGSH7j6xxWR663hw/wmqsOu4efhLrssC7VLDLwJUkHwn3TJ+D08xzOH0EGviSpMbbWoB9hfc3ST/ITSRbMdjGSJGl29PuxvBuBh5NckeTps1mQJEkavH4D/weAS4EXAZuS/F2SNyb53tkrTZIkDUpfgV9V/1ZVf1RVq4DnAHcA7wW+luSPk6yazSIlSdLMHPJKe1V1H3AlcA2wCPgZYH2SO5I8Z8D1qYMmtu7m6ts2M7F191yXInXLtjth/e80zxo5fc/ST3I08Frg54CX0vTy3wR8BFgM/Ga7ffrgy1RXTGzdzZprN7Bn734WLTyKGy9cxVnLF891WdLo23YnXH8u7NsDCxbBBetg6cq5rkoD1FfgJ/l94HyggBuAX6qq+3tO+VaSS4B/GnyJGkVJ+jpv7N0HPzaqq0RKc+LB9U3Y177m+cH1Bv6I6XdI/wzgYuCkqpoa9pMeAV4ysMo00qpq2sf4g7t4xjs/BcAz3vkpxh/cddBzJQ3QirObnn0WNM8rzp7rijRgrqWveWdi627GVixh/MFdDudLhyjJ4f9BvO3Opme/4uwZ9e5nVINmZMZr6Sd5D7Ctqj44pf1NNL3+X595mVJjMuQNe2nIlq50GH+E9Tuk/wbgrmnaJ4Cf7ffFkpyT5IEkm9v3/KceX5bktiR3Jbknyat7jr2jve6BJK/s9zUlSVL/s/S/H9gxTftO4Kn93KBdmvdq4OXAdmBjknVT5gO8E7ilqv4wyRnAp4AV7fZq4FnAfwb+OsnTq2pfn/VLktRp/fbwHwKmm8HxQprw7sdKYHNVbamqPcDNwHlTzing2Hb7OA7M+j8PuLmqvlNV/whsbu8nSZL60G8P/4+AK5MsAm5t215Ks9reFX3e4yRgW8/+duB5U865DPirJG8Bvhd4Wc+1G6Zce9LUF0hyEXARwLJly/osS5Kk0ddX4FfV7yQ5Afg9mtX1APYAV1XV+wdYz/nAde3rPR+4Icmz+724qq6hWQGQsbExp4hKktTqe6W9qnpHknfTfCYfYFNVfeMQXuthYGnP/sltW6+fB85pX+/2JMcAJ/R5rSRJOohDWku/qr5ZVRvbx6GEPcBG4LQkp7RvDawG1k055yGatwpIcjpwDM1kwXXA6iTfk+QU4DTAxZ4lSerToayl/xKaIfdlHBjWB6CqfuyJrq+qvUkuBj4DLAA+VFX3JbkcGK+qdcDbgT9O8jaaCXxrq1m94b4ktwD3A3uBNztDX5Kk/vW10l6StcAHgY/TfIHOJ4CnA6cAf1pVF89ijYfFlfaObK7UJR2e+fC7Mx9q6KrHW2mv3yH9XwYurqrzgUeBd1TVc4E/BQ51aF+SJA1Zv4F/KvDX7fZ3gCe32x8A1g64JkmSNGD9Bv5O4Cnt9sPA5EfljgeeNOiiJEnSYPU7aW898ArgXuAW4PeSvJxmRv1nZ6k2SZI0IP0G/sU0H5GDZnW9vcALaML/3bNQlyRJGqAnDPwkC2k+M/8XAFW1n/6X05UkSfPAE76HX1V7gd8Cjp79ciRJ0mzod9LeBuCs2SxEkiTNnn7fw/9j4LeTLAMmgG/2Hqyqfxh0YZIkaXD6Dfyb2uffneZY0SyVK0mS5ql+A/+UWa1CkiTNqr4Cv6q2znYhkiRp9vQV+Ele93jHq+rPB1OOJEmaDf0O6X/0IO2TX4fke/gamImtu//j+azli+e4GunIk2ROX3/xYn9v56O+PpZXVUf1PoBFwPNoltx94WwWqG6Z2LqbNdduAGDNtRv+I/wl9aeqZvQYxD127do1xz8FTaffz+E/RlXtraqNwK8CfzDYktRlG7bsZM/e/QA8unc/G7bsnOOKJGk0HFbg9/gX4AcHUIcEwKpTj2fRwuaf5dELj2LVqcfPcUWSNBoyOYTzuCclPzK1CXga8CsAVXX24EubmbGxsRofH5/rMnQYJrbuZmzFEsYf3OV7+NKQJaGfXND8lGSiqsamO9bvpL1xmgl6U2eCbADeOIPapO8yGfKGvSQNzuEuvLMf2FFV3x5wPZIkaRa48I4kqbHtzgPPS1fObS0auL4m7SV5T5I3TdP+piTvGnxZkqSh2nYnXH9us339uQfCXyOj31n6bwDumqZ9AvjZwZUjPXbhHUlD8uB62Nu+S7v3282+Rkq/gf/9wI5p2ncCTx1cOeo6F96R5si3v86BxVOr3dco6XfS3kPA2cCWKe0vBLYPtCJ12nQL7zhbXxqcfpbdzW98Hbi8fUzPj+4defoN/D8CrkyyCLi1bXsp8F7gitkoTN3kwjvS7DpoUI9fB59864H911wFY2uHUZKGpK+FdwCSvBf4RZp19AH2AFdV1SWzU9rMuPDOkcuFd6Q5Mn4dbPoEnH6eYX+EeryFd/oO/PZG3wuc0e5uqqpvDKC+WWHgH9lc7UuSDt2MV9pL8gPAwqraDmzsaT8ZeLSq/nkglUqSpFnR7yz9PwVeNU37K4EbBleORsGSJUtIMqMHMON7LFmyZI5/EpI0f/Q7aW8MePM07euB3xpcORoFu3fvnhfD8f3MRpakrui3h78Q+J5p2o85SLskSZpH+g38O4D/MU37m+l5T1+SJM1P/Q7p/xpwa5LncOBz+D8G/AjN5/H7kuQc4CpgAXBtVb1vyvErgZe0u/8J+P6q+r722D7g3vbYQ1V1br+vK0lS1/X7bXkbkjwf+F/A69rmfwD+J3BiP/dIsgC4Gng5zep8G5Osq6r7e17nbT3nvwV4bs8tvlVVZ/bzWpIk6bH6HdKnqr5YVWuq6lk0s/O/DHwc+Eyft1gJbK6qLVW1B7gZOO9xzj8f+HC/9UmSpIPrO/CTLEjyuiT/F/hH4CeADwI/1OctTgK29exvb9ume63lwCkcePsA4Jgk40k2JPmJg1x3UXvO+I4d033XjyRJ3fSEQ/pJngFcSPM1uN8EbqLp4b+hdzh+wFYDH62qfT1ty6vq4SSn0swnuLeqvtp7UVVdA1wDzUp7s1SbJElHnMft4SdZD2wAFgM/XVWnVtU7OfAdiofiYWBpz/7Jbdt0VjNlOL+qHm6ftwCf57Hv70uSpMfxREP6zwf+BLiyqv5mhq+1ETgtySntt+6tBtZNPSnJM2n+wLi9p21xku9pt08AXgDM1uiCJEkj54kC/7/QDPv/XZK7krytXVf/kFXVXuBimkl+m4Bbquq+JJcn6f2I3Wrg5nrsUm2nA+NJvgjcBrxvFt9OkCRp5PT1bXlJjgF+Cvg54Edp/lC4hOaz9LtntcLD5LflzZ2ZftPdTXc8xKe/9DVe9eyn8frnLZuzOiTpSDPjb8urqm/TfEnODUl+iGYS39uAdye5taqm+2Id6ZDddMdD/OrHm/WV1n/lEYAZhb4kqdH3x/ImVdXmqrqEZgLeTwN7Bl6VOuvTX/ra4+5Lkg7PIQf+pKraV1WfqKrHWzxHOiSvevbTHndfknR4+l1LXxqKyeH7QbyHL0k6wMDXvPP65y0z6CVpwA57SF+SJB057OFr4OrSY+Gy4+a6jKYOSRJg4GsW5De+Pi8+/56Eumyuq5Ck+cEhfc07E1t3c/Vtm5nYOi/XdJKkI5I9fM0rE1t3s+baDezZu59FC4/ixgtXcdbyxXNdliQd8ezha17ZsGUne/buZ3/Bo3v3s2HLzrkuSZJGgoGveWXVqcezaOFRLAgcvfAoVp16/FyXJEkjwSF9zStnLV/MjReuYsOWnaw69XiH8yVpQAx8zTtnLV9s0EvSgDmkL0lSBxj4kiR1gIEvSVIHGPiSJHWAga9556Y7HuIN/+cObrrjobkuRZJGhrP0Na/cdMdD/OrH7wVg/VceAfCrciVpAAx8zYokA7nPmitgzWFeu3ixH+2TpEkO6WvgquqwHzdu2MryX/kkAMt/5ZPcuGHrYd9r165dc/yTkKT5wx6+5pXJ4fs1V8BvvvaHHc6XpAGxh695ZzLkDXtJGhwDX5KkDjDwJUnqAANfkqQOMPAlSeoAA1+SpA4w8CVJ6gADX5KkDjDwJUnqAANfkqQOGGrgJzknyQNJNie5ZJrjVya5u318Ocm/9By7IMlX2scFw6xbkqQj3dDW0k+yALgaeDmwHdiYZF1V3T95TlW9ref8twDPbbeXAJcCY0ABE+21u4dVvyRJR7Jh9vBXApuraktV7QFuBs57nPPPBz7cbr8S+GxV7WpD/rPAObNarSRJI2SYgX8SsK1nf3vb9l2SLAdOAW49lGuTXJRkPMn4jh07BlK0JEmjYL5O2lsNfLSq9h3KRVV1TVWNVdXYiSeeOEulSZJ05Blm4D8MLO3ZP7ltm85qDgznH+q1kiRpimEG/kbgtCSnJFlEE+rrpp6U5JnAYuD2nubPAK9IsjjJYuAVbZskSerD0GbpV9XeJBfTBPUC4ENVdV+Sy4HxqpoM/9XAzVVVPdfuSvIumj8aAC6vql3Dql2SpCNdenJ1pIyNjdX4+Phcl6HDlIRR/bcpSbMlyURVjU13bL5O2pMkSQNk4EuS1AEGviRJHWDgS5LUAQa+JEkdYOBLktQBBr4kSR1g4EuS1AEGviRJHWDgS5LUAQa+JEkdYOBLktQBBr4kSR1g4EuS1AEGviRJHWDgS5LUAQa+JEkdYOBLktQBBr4kSR1g4EuS1AEGvuadia27H/MsSZo5A1/zysTW3ay5dgMAa67dYOhL0oAY+JpXNmzZyZ69+wF4dO9+NmzZOccVSdJoMPA1r6w69XgWHhUAFhwVVp16/BxXJEmjwcDX/JM89lmSNGML57oAdVP6CPOvvOfVjL3n4MeraoAVSdJoM/A1Jw4W1pOT9h7du5+jFx7FjReu4qzli4dcnSSNHgNf88pZyxdz44Wr2LBlJ6tOPd6wl6QBMfA175y1fLFBL0kD5qQ9SZI6wMCXJKkDDHxJkjrAwJckqQMMfEmSOsDAlySpAzKqq5Ul2QFsnes6dNhOAB6Z6yKkDvJ378i2vKpOnO7AyAa+jmxJxqtqbK7rkLrG373R5ZC+JEkdYOBLktQBBr7mq2vmugCpo/zdG1G+hy9JUgfYw5ckqQMMfEmSOsDA19Al+bUk9yW5J8ndSS5N8t4p55yZZFO7/WCS9VOO353kS8OsW5oNSfZN/ntO8pdJvm9A912b5AODuNeU+34+yQNtzXcn+a+Dfo32dVYkef1s3LurDHwNVZLnA68BfqSqngO8DLgN+Jkpp64GPtyz/5QkS9t7nD6MWqUh+VZVnVlVzwZ2AW+e64L6sKat+cyq+mg/FyRZeIivsQIw8AfIwNewPQ14pKq+A1BVj1TV3wK7kzyv57yf5rGBfwsH/ig4f8oxaVTcDpwEkGRlktuT3JXkC0me0bavTfLnSf5fkq8kef/kxUnemOTLSe4EXtDTviLJre2o2ueSLGvbr0vyh0k2JNmS5MVJPpRkU5Lr+i06yZIkf9Hef0OS57TtlyW5IcnfAzckOTHJx5JsbB8vaM97Uc+IwV1JngK8Dzi7bXvbTH+wAqrKh4+hPYAnA3cDXwb+AHhR2/7LwJXt9ipgvOeaB4FnAF9o9+8CzgC+NNf/e3z4mOkD+Eb7vAD4M+Ccdv9YYGG7/TLgY+32WmALcBxwDM0S4ktp/ph+CDgRWAT8PfCB9pq/BC5ot38O+It2+zrgZiDAecDXgR+m6QxOAGdOU+/ngQfa3+O7geOB3wcubY//GHB3u31Ze58ntfs3AT/abi8DNvXU94J2+8nAQuDFwCfn+v+fUXoc6hCLNCNV9Y0kZwFnAy8BPpLkEuAjwBeSvJ3vHs4H2EkzCrAa2AT8+xDLlmbTk5LcTdOz3wR8tm0/Drg+yWlAAUf3XPO5qvpXgCT3A8tp1sD/fFXtaNs/Ajy9Pf/5wOva7RuA9/fc6y+rqpLcC/xzVd3bXn8fzbD63dPUvKaqxid3kvwo8JMAVXVrkuOTHNseXldV32q3XwackWTy0mOTPJnmj5PfTXIj8OdVtb3nHA2IQ/oauqraV1Wfr6pLgYuBn6yqbcA/Ai+i+Q/HR6a59CPA1Ticr9Hyrao6kya0w4H38N8F3FbNe/s/TtObn/Sdnu19MKPO2+S99k+57/4Z3nfSN3u2jwJW1YH3/0+qqm9U1fuAC4EnAX+f5JkDeF1NYeBrqJI8o+2xTDqTA99q+GHgSmBLVW2f5vKP0/RMPjOrRUpzoKr+HfgF4O3tBLfjgIfbw2v7uMUdwIva3vXRwE/1HPsCzcgZwBpg/dSLZ2h9e1+SvJhmns7Xpznvr4C3TO4kObN9/sGqureqrgA2As8E/g14yoDr7DQDX8P2ZJphyvuT3EPzXvxl7bE/A57FQXrwVfVvVXVFVe0ZSqXSkFXVXcA9NBNT3w+8N8ld9NHTrqqv0fwu3U4zRL6p5/BbgDe2v3NvAN462Mq5DDirvf/7gAsOct4vAGPt5L77gTe17b/YfizxHuBR4NM0P4d9Sb7opL3BcGldSZI6wB6+JEkdYOBLktQBBr4kSR1g4EuS1AEGviRJHWDgS5LUAQa+JEkd8P8BRwxIZRxqTYcAAAAASUVORK5CYII=\n"
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"plt.figure(figsize=(8, 4))\n",
"plt.plot([1]*10, svm_scores, '.')\n",
"plt.plot([2]*10, forest_scores, '.')\n",
"plt.boxplot([svm_scores, forest_scores], labels=('SVM', 'Random Forest'))\n",
"plt.ylabel('Accuracy', fontsize=14)\n",
"plt.show()"
]
},
{
"source": [
"为了进一步改善结果,可以进行如下操作:\n",
"- 对更多模型,使用cross validation和grid search调整超参数\n",
"- 使用更多的特征工程,例如:\n",
" - 是**SibSp**和**Parch**的和代替他们\n",
" - 尝试找出与**Survived**属性很好相关的部分\n",
"- 尝试把年龄属性更改为年龄段属性\n"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Survived\n",
"AgeBucket \n",
"0.0 0.576923\n",
"15.0 0.362745\n",
"30.0 0.423256\n",
"45.0 0.404494\n",
"60.0 0.240000\n",
"75.0 1.000000"
],
"text/html": "\n\n
\n \n \n | \n Survived | \n
\n \n AgeBucket | \n | \n
\n \n \n \n 0.0 | \n 0.576923 | \n
\n \n 15.0 | \n 0.362745 | \n
\n \n 30.0 | \n 0.423256 | \n
\n \n 45.0 | \n 0.404494 | \n
\n \n 60.0 | \n 0.240000 | \n
\n \n 75.0 | \n 1.000000 | \n
\n \n
\n
"
},
"metadata": {},
"execution_count": 44
}
],
"source": [
"train_data[\"AgeBucket\"] = train_data[\"Age\"] // 15 * 15\n",
"train_data[[\"AgeBucket\", \"Survived\"]].groupby(['AgeBucket']).mean()"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Survived\n",
"RelativesOnboard \n",
"0 0.303538\n",
"1 0.552795\n",
"2 0.578431\n",
"3 0.724138\n",
"4 0.200000\n",
"5 0.136364\n",
"6 0.333333\n",
"7 0.000000\n",
"10 0.000000"
],
"text/html": "\n\n
\n \n \n | \n Survived | \n
\n \n RelativesOnboard | \n | \n
\n \n \n \n 0 | \n 0.303538 | \n
\n \n 1 | \n 0.552795 | \n
\n \n 2 | \n 0.578431 | \n
\n \n 3 | \n 0.724138 | \n
\n \n 4 | \n 0.200000 | \n
\n \n 5 | \n 0.136364 | \n
\n \n 6 | \n 0.333333 | \n
\n \n 7 | \n 0.000000 | \n
\n \n 10 | \n 0.000000 | \n
\n \n
\n
"
},
"metadata": {},
"execution_count": 77
}
],
"source": [
"train_data[\"RelativesOnboard\"] = train_data[\"SibSp\"] + train_data[\"Parch\"]\n",
"train_data[[\"RelativesOnboard\", \"Survived\"]].groupby(['RelativesOnboard']).mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
]
}