sklearn 有很好的文档。这是一个完整的示例数据集:http ://scikit-learn.org/stable/auto_examples/linear_model/plot_ols.html
您遇到的最大问题是您的数据集。喜欢你的代码:
y_true = ['CRIM', 'RM', 'PTRATIO']
y_pred = ['PRICE']
这甚至不是真实数据,它只是 2 个字符串标签列表,所以这当然行不通:
mean_squared_error(y_true, y_pred)
从我发布的示例中,您可以尝试使用这种“hello world”类型代码(使用现有数据集),以确保代码正常工作,然后您需要做的就是用您自己的数据替换数据集。如您所见,大部分代码专门用于准备数据,以便正确加载到线性回归函数中:
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets, linear_model
# Load the diabetes dataset
diabetes = datasets.load_diabetes()
# Use only one feature
diabetes_X = diabetes.data[:, np.newaxis, 2]
# Split the data into training/testing sets
diabetes_X_train = diabetes_X[:-20]
diabetes_X_test = diabetes_X[-20:]
# Split the targets into training/testing sets
diabetes_y_train = diabetes.target[:-20]
diabetes_y_test = diabetes.target[-20:]
# Create linear regression object
regr = linear_model.LinearRegression()
# Train the model using the training sets
regr.fit(diabetes_X_train, diabetes_y_train)
print("Mean squared error: %.2f" % np.mean((regr.predict(diabetes_X_test) - diabetes_y_test) ** 2))