4

为了简单起见,我将 500 个样本划分为 10,000 多行数据集。请将 X 和 y 复制并粘贴到您的 IDE 中。

X =

array([ -8.93,  -0.17,   1.47,  -6.13,  -4.06,  -2.22,  -2.11,  -0.25,
         0.25,   0.49,   1.7 ,  -0.77,   1.07,   5.61, -11.95,  -3.8 ,
        -3.42,  -2.55,  -2.44,  -1.99,  -1.7 ,  -0.98,  -0.91,  -0.91,
        -0.25,   1.7 ,   2.88,  -6.9 ,  -4.07,  -1.35,  -0.33,   0.63,
         0.98,  -3.31,  -2.61,  -2.61,  -2.17,  -1.38,  -0.77,  -0.25,
        -0.08,  -1.2 ,  -3.1 ,  -1.07,  -0.7 ,  -0.41,  -0.33,   0.41,
         0.77,   0.77,   1.14,   2.17,  -7.92,  -3.8 ,  -2.11,  -2.06,
        -1.2 ,  -1.14,   0.  ,   0.56,   1.47,  -1.99,  -0.17,   2.44,
        -5.87,  -3.74,  -3.37,  -2.88,  -0.49,  -0.25,  -0.08,   0.33,
         0.33,   0.84,   1.64,   2.06,   2.88,  -4.58,  -1.82,  -1.2 ,
         0.25,   0.25,   0.63,   2.61,  -5.36,  -1.47,  -0.63,   0.  ,
         0.63,   1.99,   1.99, -10.44,  -2.55,   0.33,  -8.93,  -5.87,
        -5.1 ,  -2.78,  -0.25,   1.47,   1.93,   2.17,  -5.36,  -5.1 ,
        -3.48,  -2.44,  -2.06,  -2.06,  -1.82,  -1.58,  -1.58,  -0.63,
        -0.33,   0.  ,   0.17,  -3.31,  -0.25,  -5.1 ,  -3.8 ,  -2.55,
        -1.99,  -1.7 ,  -0.98,  -0.91,  -0.63,  -0.25,   0.77,   0.91,
         0.91,  -9.43,  -8.42,  -2.72,  -2.55,  -1.26,   0.7 ,   0.77,
         1.07,   1.47,   1.7 ,  -1.82,  -1.47,   0.17,   1.26,  -5.36,
        -1.52,  -1.47,  -0.17,  -3.48,  -3.31,  -2.06,  -1.47,   0.17,
         0.25,   1.7 ,   2.5 ,  -9.94,  -6.08,  -5.87,  -3.37,  -2.44,
        -2.17,  -1.87,  -0.98,  -0.7 ,  -0.49,   0.41,   1.47,   2.28,
       -14.95, -12.44,  -6.39,  -4.33,  -3.8 ,  -2.72,  -2.17,  -1.2 ,
         0.41,   0.77,   0.84,   2.51,  -1.99,  -1.7 ,  -1.47,  -1.2 ,
         0.49,   0.63,   0.84,   0.98,   1.14,   2.5 ,  -2.06,  -1.26,
        -0.33,   0.17,   4.58,  -7.41,  -5.87,   1.2 ,   1.38,   1.58,
         1.82,   1.99,  -6.39,  -2.78,  -2.67,  -1.87,  -1.58,  -1.47,
         0.84, -10.44,  -7.41,  -3.05,  -2.17,  -1.07,  -1.07,  -0.91,
         0.25,   1.82,   2.88,  -6.9 ,  -1.47,   0.33,  -8.42,  -3.8 ,
        -1.99,  -1.47,  -1.47,  -0.56,   0.17,   0.17,   0.25,   0.56,
         4.58,  -3.48,  -2.61,  -2.44,  -0.7 ,   0.63,   1.47,   1.82,
       -13.96,  -9.43,  -2.67,  -1.38,  -0.08,   0.  ,   1.82,   3.05,
        -4.58,  -3.31,  -0.98,  -0.91,  -0.7 ,   0.77,  -0.7 ,  -0.33,
         0.56,   1.58,   1.7 ,   2.61,  -4.84,  -4.84,  -4.32,  -2.88,
        -1.38,  -0.98,  -0.17,   0.17,   0.49,   2.44,   4.32,  -3.48,
        -3.05,   0.56,  -8.42,  -3.48,  -2.61,  -2.61,  -2.06,  -1.47,
        -0.98,   0.  ,   0.08,   1.38,   1.93,  -9.94,  -2.72,  -1.87,
        -1.2 ,  -1.07,   1.58,   4.58,  -6.64,  -2.78,  -0.77,  -0.7 ,
        -0.63,   0.49,   1.07,  -8.93,  -4.84,  -1.7 ,   1.76,   3.31,
       -11.95,  -3.16,  -3.05,  -1.82,  -0.49,  -0.41,   0.56,   1.58,
       -13.96,  -3.05,  -2.78,  -2.55,  -1.7 ,  -1.38,  -0.91,  -0.33,
         1.2 ,   1.32,   1.47,  -2.06,  -1.82,  -7.92,  -6.33,  -4.32,
        -3.8 ,  -1.93,  -1.52,  -0.98,  -0.49,  -0.33,   0.7 ,   1.52,
         1.76,  -8.93,  -7.41,  -2.88,  -2.61,  -2.33,  -1.99,  -1.82,
        -1.64,  -0.84,   1.07,   2.06,  -3.96,  -2.44,  -1.58,   0.  ,
        -3.31,  -2.61,  -1.58,  -0.25,   0.33,   0.56,   0.84,   1.07,
        -1.58,  -0.25,   1.35,  -1.99,  -1.7 ,  -1.47,  -1.47,  -0.84,
        -0.7 ,  -0.56,  -0.33,   0.56,   0.63,   1.32,   2.28,   2.28,
        -2.72,  -0.25,   0.41,  -6.9 ,  -4.42,  -4.32,  -1.76,  -1.2 ,
        -1.14,  -1.07,   0.56,   1.32,   1.52, -14.97,  -7.41,  -5.1 ,
        -2.61,  -1.93,  -0.98,   0.17,   0.25,   0.41,  -4.42,  -2.61,
        -0.91,  -0.84,   2.39,  -2.61,  -1.32,   0.41,  -6.9 ,  -5.61,
        -4.06,  -3.31,  -1.47,  -0.91,  -0.7 ,  -0.63,   0.33,   1.38,
         2.61,  -2.29,   3.06,   4.44, -10.94,  -4.32,  -3.42,  -2.17,
        -1.7 ,  -1.47,  -1.32,  -1.07,  -0.7 ,   0.  ,   0.77,   1.07,
        -3.31,  -2.88,  -2.61,  -1.47,  -1.38,  -0.63,  -0.49,   1.07,
         1.52,  -3.8 ,  -1.58,  -0.91,  -0.7 ,   0.77,   3.42,  -8.42,
        -2.88,  -1.76,  -1.76,  -0.63,  -0.25,   0.49,   0.63,  -6.9 ,
        -4.06,  -1.82,  -1.76,  -1.76,  -1.38,  -0.91,  -0.7 ,   0.17,
         1.38,   1.47,   1.47, -11.95,  -0.98,  -0.56, -14.97,  -9.43,
        -8.93,  -2.72,  -2.61,  -1.64,  -1.32,  -0.56,  -0.49,   0.91,
         1.2 ,   1.47,  -3.8 ,  -3.06,  -2.51,  -1.04,  -0.33,  -0.33,
        -3.31,  -3.16,  -3.05,  -2.61,  -1.47,  -1.07,   2.17,   3.1 ,
        -2.61,  -0.25,  -3.85,  -2.44])

y =

array([1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0,
       1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1,
       0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1,
       1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0,
       0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1,
       1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1,
       1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0,
       0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1,
       0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1,
       1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0,
       0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1,
       1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0,
       1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,
       0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0,
       1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1,
       1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,
       0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0,
       0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0,
       0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0,
       0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0,
       1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1])

初始化和培训:

from sklearn.linear_model import LogisticRegression
model = LogisticRegression()
model.fit(X, y)

交叉验证:

from sklearn.model_selection import cross_val_score
cross_val_score(model, X, y, cv=10, scoring='r2').mean()

-0.3339677563815496(负R2?)

看它是否接近模型的真实 R2。我这样做了:

from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=None, shuffle=False)

r2_score(y_test, model.predict_proba(X_test)[:,1], multioutput='variance_weighted')

0.32642659661798396

这个 R2 对模型的拟合优度更有意义,看起来两个 R2 只是一个 +/- 符号开关,但事实并非如此。在我使用更大样本的模型中,R2 cross-val 为 -0.24,R2 test 为 0.18。而且,当我添加一个似乎有利于模型的功能时,R2 测试上升,R2 交叉验证下降

此外,如果您将 LogisticRegression 切换为 LinearRegression,R2 交叉验证现在为正,并且接近 R2 测试。是什么导致了这个问题?

4

2 回答 2

7

TLDR:R2 可能是负面的,你误解了train_test_split结果。

我将在下面解释这两种说法。

cross_val_score标志翻转errorloss指标

文档中,您可以看到cross_val_score实际上翻转了某些指标的符号。但仅适用于errororloss指标(越低越好),而不适用于score指标(越高越好):

所有 scorer 对象都遵循较高返回值优于较低返回值的约定。因此,衡量模型和数据之间距离的指标,如 metrics.mean_squared_error,可作为 neg_mean_squared_error 使用,它返回指标的否定值。

由于r2是一个score指标,它不会翻转符号。你得到一个-0.33交叉验证。请注意,这是正常的。来自r2_score 文档

最好的分数是 1.0,它可以是负数(因为模型可以任意变坏)。始终预测 y 的期望值的常量模型,不考虑输入特征,将获得 0.0 的 R^2 分数。

所以这将我们引向第二部分:为什么使用 CV 和训练/测试拆分得到如此不同的结果?

CV 和训练/测试拆分结果之间的差异

使用 .获得更好结果的原因有两个train_test_split

评估r2概率而不是类(您正在使用predict_proba而不是减少predict错误的危害:

print(r2_score(y_test, model.predict_proba(X_test)[:,1], multioutput='variance_weighted'))
 0.19131536389654913

尽管:

 print(r2_score(y_test, model.predict(X_test)))
 -0.364200082678793

10折叠 cv 的平均值,不检查方差,这很高。如果你检查方差和结果的细节,你会发现方差很大:

scores = cross_val_score(model, X, y, cv=10, scoring='r2')
scores
array([-0.67868339, -0.03918495,  0.04075235, -0.47783251, -0.23152709,
   -0.39573071, -0.72413793, -0.66666667,  0.        , -0.16666667])

scores.mean(), scores.std() * 2
(-0.3339677563815496, 0.5598543351649792)

希望它有所帮助!

于 2018-11-22T03:52:13.577 回答
3

R2 可能为负数。以下段落来自“确定系数”的维基百科页面

在某些情况下,R2 的计算定义可能会产生负值,具体取决于所使用的定义。当与相应结果进行比较的预测没有从使用这些数据的模型拟合过程中得出时,就会出现这种情况。即使使用了模型拟合程序,R2 仍可能为负值,例如当进行线性回归而不包括截距时,或使用非线性函数拟合数据时。根据这一特定标准,在出现负值的情况下,数据的平均值比拟合函数值更适合结果。由于决定系数的最一般定义也称为 Nash-Sutcliffe 模型效率系数,因此最后一种表示法在许多领域中是首选的,

似乎预测比水平线更糟糕。

于 2018-11-22T03:40:10.453 回答