1

我正在为泰坦尼克号的案例执行 Kaggle 的学习任务。

如果我手动分离数据或使用 cross_val_score 执行线性回归,我有不同的预测精度。逻辑回归也是如此。

例子。

- 线性回归。

手动的

Algorithm = LinearRegression()
kf = KFold(dataset.shape[0], n_folds=3, random_state=1)
predictions = []

for train, test in kf:

    train_predictors = (dataset[Predictors].iloc[train])
    train_target = dataset['Survived'].iloc[train]
    Algorithm.fit(train_predictors, train_target)
    test_predictions = Algorithm.predict(dataset[Predictors].iloc[test])
    predictions.append(test_predictions)

predictions = np.concatenate(predictions, axis=0)
print(predictions.shape[0])
realed = list(dataset.Survived)
predictions[predictions > 0.5] = 1
predictions[predictions <= 0.5] = 0

accuracy2 = sum(predictions[predictions == dataset["Survived"]]) / len(predictions)
print("Tochnost prognoza: ", accuracy2 * 100, " %")

结果 - 78,34%

Cross_val_score

scores=cross_val_score(LinearRegression(), dataset[Predictors], dataset["Survived"], cv=3)
print(scores.mean())

结果 - 37.5%

- 逻辑回归。

在这里,我有 26.15% 的手动功能和 78.78% 的 cross_val_score 功能。

为什么??

4

1 回答 1

3

您的代码有几处看起来非常错误。

  1. 你的准确度计算是错误的
    这一行:

    accuracy2 = sum(predictions[predictions == dataset["Survived"]]) / len(predictions)
    

    不计算准确度。当你有正确的预测时,它的作用是取你所做的预测的平均值。这没有多大意义;)。
    不过,这很容易解决:

    accuracy2 = sum(predictions == dataset["Survived"] / len(predictions)
    
  2. 线性回归实际上执行回归
    使用线性回归来执行分类任务并不是一个好主意。在(二进制)分类中,您期望输出范围为 [0; 1](概率),而线性回归通常会给你一个无限的范围。
    由于统计学家是线性回归的忠实拥护者,他们发明了逻辑回归,这实际上是对转换后的目标值的线性回归。
    底线:使用逻辑回归(不是线性回归)进行分类。

  3. 评分方法不是你想的那样
    cross_val_score接受一个scoring参数。在这里您没有指定它(所以它是None),这意味着它将查找估计器的默认得分方法。的默认评分方法LinearRegression 是 not accuracy。它是 R^2 系数。这与回归有关,而不是真正与您尝试做的事情有关。

    所以当你这样做时:

    scores=cross_val_score(LinearRegression(), dataset[Predictors], dataset["Survived"], cv=3)
    print(scores.mean())
    

    您得到的是 3 倍交叉验证的平均 R^2 系数。
    当你这样做时,LogisticRegression你会得到平均准确度,这就是你想要的。

第 1 点和第 2 点解释了使用LogisticRegression和使用cross_val_scoreon获得的结果LinearRegression
我还不确定第一个案例,如果我找到一个好的解释,我会更新我的帖子。我觉得这令人惊讶,因为您在计算准确性方面所犯的错误总是会低估结果。当然,除非这不是您运行的实际代码。

于 2015-08-23T11:31:12.767 回答