4.3 Polynomial Regression
[1]:
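%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

plt.style.use('ggplot')

m = 100                                  # number of samples
X = 6 * np.random.rand(m, 1) - 3         # m points drawn uniformly from [-3, 3)
# Quadratic ground truth y = 0.5x^2 + x + 2 plus standard Gaussian noise
y = 0.5 * (X ** 2) + X + 2 + np.random.randn(m, 1)

plt.scatter(X, y)
plt.show()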
[2]:
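from sklearn.preprocessing import PolynomialFeatures
# Add the square of each feature as a new feature (no constant bias column)
poly_features = PolynomialFeatures(degree=2, include_bias=False)
X_poly = poly_features.fit_transform(X)
X[0], X_poly[0]   # the original feature x and the expanded row [x, x^2]
[2]:
(array([0.79304176]), array([0.79304176, 0.62891523]))
[3]:
from sklearn.linear_model import LinearRegression
# Ordinary linear regression on the expanded features [x, x^2]
lin_reg = LinearRegression()
lin_reg.fit(X_poly, y)
lin_reg.intercept_, lin_reg.coef_   # intercept, then the coefficients of x and x^2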
[3]:
(array([2.22582176]), array([[1.0596753 , 0.44806298]]))
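Fitting LinearRegression on \([x, x^2]\) is ordinary least squares on a design matrix with the columns \(1\), \(x\), \(x^2\). As a minimal sketch (not a description of how scikit-learn computes the fit internally), the same parameters can be recovered with np.linalg.lstsq. The seed and the names X_demo, y_demo, A, theta and reg_demo are illustrative assumptions; the data are regenerated under new names so the notebook's X, y and lin_reg are left untouched.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

rng = np.random.default_rng(42)          # seed chosen only for this sketch
m_demo = 100
X_demo = 6 * rng.random((m_demo, 1)) - 3
y_demo = 0.5 * X_demo ** 2 + X_demo + 2 + rng.standard_normal((m_demo, 1))

# Design matrix with an explicit bias column: [1, x, x^2]
A = np.hstack([np.ones((m_demo, 1)), X_demo, X_demo ** 2])

# For one feature, PolynomialFeatures(degree=2, include_bias=False) yields exactly [x, x^2]
X_demo_poly = PolynomialFeatures(degree=2, include_bias=False).fit_transform(X_demo)
print(np.allclose(X_demo_poly, A[:, 1:]))  # True

# Ordinary least squares: theta = [intercept, coef of x, coef of x^2]
theta, *_ = np.linalg.lstsq(A, y_demo, rcond=None)

# The same model fitted the way the cells above do it
reg_demo = LinearRegression().fit(X_demo_poly, y_demo)
print(theta.ravel())
print(reg_demo.intercept_, reg_demo.coef_)
```

The two printed parameter vectors should agree up to floating-point rounding: estimating the intercept separately, as LinearRegression does, or through a column of ones gives the same least-squares solution.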
[4]:
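# 20 evenly spaced points across the data range, used to draw the fitted curve
X_draw = np.linspace(-3, 3, 20)
# y_hat = theta_2 * x^2 + theta_1 * x + theta_0, using the learned parameters
y_hat = lin_reg.coef_[0][1] * (X_draw ** 2) + lin_reg.coef_[0][0] * X_draw + lin_reg.intercept_
plt.scatter(X, y)
plt.plot(X_draw, y_hat, 'b')   # fitted quadratic in blue
plt.show()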
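Reading off the output of cell [3], the learned model can be set against the function used to generate the data in cell [1]. The coefficients below are rounded to three decimals; since no random seed is set, they will differ slightly on each run.

```latex
\hat{y} = 0.448\,x^2 + 1.060\,x + 2.226
\qquad\text{vs.}\qquad
y = 0.5\,x^2 + 1.0\,x + 2 + \varepsilon,\quad \varepsilon \sim \mathcal{N}(0, 1)
```

The estimates are close to the true coefficients 0.5, 1 and 2; the remaining gap is expected to come mostly from the Gaussian noise and the finite sample of 100 points.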
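If there are two features \(a\) and \(b\), PolynomialFeatures with degree=3 keeps \(a\) and \(b\) and adds not only the features \(a^2\), \(a^3\), \(b^2\) and \(b^3\), but also the combinations \(ab\), \(a^2b\) and \(ab^2\). In general, PolynomialFeatures(degree=d) transforms an array containing \(n\) features into an array containing \(\frac{(n+d)!}{d!\,n!}\) features. This count includes the constant bias column that is added by default (include_bias=True); with include_bias=False, as in cell [2], there is one column fewer.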