我已经检查了涵盖该主题的其他问题,例如this、this、this、this和this以及一些很棒的博客文章blog1、blog2和blog3(感谢各自的作者),但没有成功。
我想要做的是转换其值低于某个阈值的行X
,但仅转换那些与目标y
(y != 9
)中的某些特定类相对应的行。阈值是根据其他类 ( y == 9
) 计算的。但是,我在理解如何正确实施这一点时遇到了问题。
由于我想对此进行参数调整和交叉验证,我将不得不使用管道进行转换。我的自定义变压器类如下所示。请注意,我没有包括在内,因为我认为我需要在函数TransformerMixin
中考虑到。y
fit_transform()
class CustomTransformer(BaseEstimator):
def __init__(self, percentile=.90):
self.percentile = percentile
def fit(self, X, y):
# Calculate thresholds for each column
thresholds = X.loc[y == 9, :].quantile(q=self.percentile, interpolation='linear').to_dict()
# Store them for later use
self.thresholds = thresholds
return self
def transform(self, X, y):
# Create a copy of X
X_ = X.copy(deep=True)
# Replace values lower than the threshold for each column
for p in self.thresholds:
X_.loc[y != 9, p] = X_.loc[y != 9, p].apply(lambda x: 0 if x < self.thresholds[p] else x)
return X_
def fit_transform(self, X, y=None):
return self.fit(X, y).transform(X, y)
然后将其馈入管道和后续的 GridSearchCV。我在下面提供了一个工作示例。
imports...
# Create some example data to work with
random.seed(12)
target = [randint(1, 8) for _ in range(60)] + [9]*40
shuffle(target)
example = pd.DataFrame({'feat1': sample(range(50, 200), 100),
'feat2': sample(range(10, 160), 100),
'target': target})
example_x = example[['feat1', 'feat2']]
example_y = example['target']
# Create a final nested pipeline where the data pre-processing steps and the final estimator are included
pipeline = Pipeline(steps=[('CustomTransformer', CustomTransformer(percentile=.90)),
('estimator', RandomForestClassifier())])
# Parameter tuning with GridSearchCV
p_grid = {'estimator__n_estimators': [50, 100, 200]}
gs = GridSearchCV(pipeline, p_grid, cv=10, n_jobs=-1, verbose=3)
gs.fit(example_x, example_y)
上面的代码给了我以下错误。
/opt/anaconda3/envs/Python37/lib/python3.7/concurrent/futures/_base.py in __get_result(self)
382 def __get_result(self):
383 if self._exception:
--> 384 raise self._exception
385 else:
386 return self._result
TypeError: transform() missing 1 required positional argument: 'y'
我还尝试了其他方法,例如fit()
在transform()
. 但是,由于交叉验证期间的训练索引和测试索引不同,因此在替换 中的值时会产生索引错误transform()
。
那么,有没有聪明的方法来解决这个问题呢?