0

我有一个问题,我正在尝试实施KFoldand cross_val_score。我的目标是计算mean_squared_error,为此我使用了以下代码:

from sklearn import linear_model
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold, cross_val_score

x = np.random.random((10000,20))
y = np.random.random((10000,1))

x_train = x[7000:]
y_train = y[7000:]

x_test = x[:7000]
y_test = y[:7000]

Model = linear_model.LinearRegression()
Model.fit(x_train,y_train)

y_predicted  = Model.predict(x_test)

MSE = mean_squared_error(y_test,y_predicted)
print(MSE)

kfold = KFold(n_splits = 100, random_state = None, shuffle = False)

results = cross_val_score(Model,x,y,cv=kfold, scoring='neg_mean_squared_error')
print(results.mean())

我认为这里没问题,我得到了以下结果:

结果:0.0828856459279和 -0.083069435946

但是当我尝试在其他示例(来自 Kaggle 房价的数据)上执行此操作时,它无法正常工作,至少我认为是这样。

train = pd.read_csv('train.csv')

Insert missing values...
...

train = pd.get_dummies(train)
y = train['SalePrice']
train = train.drop(['SalePrice'], axis = 1)

x_train = train[:1000].values.reshape(-1,339)
y_train = y[:1000].values.reshape(-1,1)
y_train_normal = np.log(y_train)

x_test = train[1000:].values.reshape(-1,339)
y_test = y[1000:].values.reshape(-1,1)

Model = linear_model.LinearRegression()
Model.fit(x_train,y_train_normal)

y_predicted = Model.predict(x_test)
y_predicted_transform = np.exp(y_predicted)

MSE = mean_squared_error(y_test, y_predicted_transform)
print(MSE)

kfold = KFold(n_splits = 10, random_state = None, shuffle = False)

results = cross_val_score(Model,train,y, cv = kfold, scoring = "neg_mean_squared_error")
print(results.mean())

在这里,我得到以下结果:0.912874946869-6.16986926564e+16

显然,mean_squared_error“手动”mean_squared_error计算与通过KFold.

我对我在哪里犯错感兴趣?

4

1 回答 1

0

差异是因为,与您的第一种方法(训练/测试集)相比,在您的 CV 方法中,您使用非标准化 y数据来拟合回归,因此您的 MSE 很大。要获得可比较的结果,您应该执行以下操作:

y_normal = np.log(y)
y_test_normal = np.log(y_test)

MSE = mean_squared_error(y_test_normal, y_predicted) # NOT y_predicted_transform
results = cross_val_score(Model, train, y_normal, cv = kfold, scoring = "neg_mean_squared_error")
于 2018-02-09T16:10:51.797 回答