我需要修复 scikit-learn 估计器的参数值。我仍然需要能够更改估算器的所有其他参数,并在 scikit-learn 工具(例如 Pipelines 和 GridSearchCV)中使用估算器。
我试图定义一个继承自 scikit-learn 估计器的新类。例如,在这里我试图创建一个修复n_estimators=5
RandomForestClassifier 的新类。
class FiveTreesClassifier(RandomForestClassifier):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.n_estimators = 5
fivetrees = FiveTreesClassifier()
randomforest = RandomForestClassifier(n_estimators=5)
# This passes.
assert fivetrees.n_estimators == randomforest.n_estimators
# This fails: the params of fivetrees is an empty dict.
assert fivetrees.get_params() == randomforest.get_params()
不可靠的事实get_params()
意味着我不能在 Pipelines 和 GridSearchCV 中使用新的估计器(如此处所述)。
我正在使用 scikit-learn 0.24.2,但我认为它实际上与新版本相同。
我更喜欢让我在修复超参数值的同时定义一个新类的答案。我也会接受使用其他技术的答案。我也将感谢我为什么应该/不应该这样做的详尽解释!