1

我正在使用其他人的 Scikit-Learn 代码来构建预测工具。原始代码工作得很好,但我需要添加sample_weight到预测工具中。

在不同的文档中搜索了解决方案后,我发现主要问题是 Scikit-Learn 中的管道不能sample_weight很好地支持。


# creating pipeline
pipeline = make_pipeline(preprocessing.StandardScaler(), RandomForestRegressor(n_estimators=100))

hyperparameters = {'randomforestregressor__max_features': ['auto'],
                   'randomforestregressor__max_depth': [None]   }


clf = GridSearchCV(pipeline, hyperparameters, cv=10, verbose=10)

clf.fit(X_train, Y_train
        #        , fit_params={'sample_weight': W_train}
        # , fit_params={'sample_weight':W_train}
        # , **{'randomforestregressor__sample_weight': W_train}
        )

# testing model
pred = clf.predict(X_test)
r2_score(Y_test, pred)
mean_squared_error(Y_test, pred)
print(r2_score(Y_test, pred))
print(mean_squared_error(Y_test, pred))


# 保存模型以便将来使用
joblib.dump(clf, 'rf_regressor.pkl')

我试图插入sample_weight不同的位置,但都显示失败。谁能帮我告诉我在哪里插入sample_weightwith pipeline,或者在不使用的情况下实现这些步骤(包括sample_weightpipeline

4

1 回答 1

0

我认为,问题必须与W_train因为在下面找到我的示例和您的代码,它工作得很好。

from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import GridSearchCV

# creating pipeline
pipeline = make_pipeline(StandardScaler(),
                         RandomForestRegressor(n_estimators=100))


from sklearn.datasets import load_diabetes

X, y = load_diabetes(return_X_y=True)
hyperparameters = {'randomforestregressor__max_features': ['auto'],
                   'randomforestregressor__max_depth': [None]   }


clf = GridSearchCV(pipeline, hyperparameters, cv=10, verbose=10)

clf.fit(X , y,
        **{'randomforestregressor__sample_weight': np.random.choice([0,2,3,5],size=len(X))})

#
Fitting 10 folds for each of 1 candidates, totalling 10 fits
[CV] randomforestregressor__max_depth=None, randomforestregressor__max_features=auto 
[CV]  randomforestregressor__max_depth=None, randomforestregressor__max_features=auto, score=0.385, total=   0.2s
...
于 2019-07-11T14:27:57.753 回答