1

我需要修复 scikit-learn 估计器的参数值。我仍然需要能够更改估算器的所有其他参数,并在 scikit-learn 工具(例如 Pipelines 和 GridSearchCV)中使用估算器。

我试图定义一个继承自 scikit-learn 估计器的新类。例如,在这里我试图创建一个修复n_estimators=5RandomForestClassifier 的新类。

class FiveTreesClassifier(RandomForestClassifier):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.n_estimators = 5


fivetrees = FiveTreesClassifier()
randomforest = RandomForestClassifier(n_estimators=5)

# This passes.
assert fivetrees.n_estimators == randomforest.n_estimators
# This fails: the params of fivetrees is an empty dict.
assert fivetrees.get_params() == randomforest.get_params()

不可靠的事实get_params()意味着我不能在 Pipelines 和 GridSearchCV 中使用新的估计器(如此所述)。

我正在使用 scikit-learn 0.24.2,但我认为它实际上与新版本相同。

我更喜欢让我在修复超参数值的同时定义一个新类的答案。我也会接受使用其他技术的答案。我也将感谢我为什么应该/不应该这样做的详尽解释!

4

1 回答 1

1

您可以使用functools.partial

NewEstimator = partial(RandomForestClassifier, n_estimators=5)
new_estimator = NewEstimator()
于 2021-11-24T23:49:10.890 回答