
I am trying to write a feature-extraction pipeline based on sklearn. My design for the pipeline code breaks down into a few parts:

  1. A parent class where all data preprocessing (if needed) can happen:
from sklearn.base import BaseEstimator, TransformerMixin

class FeatureExtractor(BaseEstimator, TransformerMixin):
    """This is the parent class for all feature extractors."""
    def __init__(self, raw_data = {}):
        self.raw_data = raw_data

    def fit(self, X, y=None):
        return self
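  2. A decorator that defines the execution order of feature extraction, so that the case where one feature depends on another is handled correctly:
# A decorator to assign the order of feature extraction within feature extractor classes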
def feature_order(order):
   def order_assignment(to_func):
       to_func.order = order
       return to_func
   return order_assignment
  3. Finally, one of the child classes where all the feature extraction happens:
class ChilldFeatureExtractor1(FeatureExtractor):
   """This is the one of the child feature extractor class."""

   def __init__(self, raw_data = {}):
       super().__init__(raw_data)
       self.raw_data = raw_data

   @feature_order(1)
   def foo_plus_one(self):
       return self.raw_data['foo'] + 1

   # This feature extractor depends on value populated in previous feature extractor
   @feature_order(2)
   def foo_plus_one_plus_one(self):
       return self.raw_data['foo_plus_one'] + 1

   def transform(self):
       functions = sorted(
           #get a list of extractor functions with attribute order
           [
           getattr(self, field) for field in dir(self)
           if hasattr(getattr(self, field), "order")
           ],
           #sort the feature extractor functions by their order
           key = (lambda field: field.order)
           )

       for func in functions:
           feature_name = func.__name__
           feature_value = func()
           self.raw_data[feature_name] = feature_value

       return self.raw_data

A small input to test this code:

if __name__ == '__main__':
    raw_data = {'foo': 1, 'bar': 2}
    fe = ChilldFeatureExtractor1(raw_data)
    print(fe.transform())

This gives the error:

Traceback (most recent call last):
  File "/Users/temporaryadmin/deleteme.py", line 55, in <module>
    print(fe.transform())
  File "/Users/temporaryadmin/deleteme.py", line 37, in transform
    [
  File "/Users/temporaryadmin/deleteme.py", line 39, in <listcomp>
    if hasattr(getattr(self, field), "order")
  File "/Users/temporaryadmin/opt/miniconda3/envs/voutopia/lib/python3.8/site-packages/sklearn/base.py", line 450, in _repr_html_
    raise AttributeError("_repr_html_ is only defined when the "
AttributeError: _repr_html_ is only defined when the 'display' configuration option is set to 'diagram'

However, when I do not inherit from the sklearn classes in the base class, i.e. class FeatureExtractor():, I get the correct output:

{'foo': 1, 'bar': 2, 'foo_plus_one': 2, 'foo_plus_one_plus_one': 3}

Any pointers on this?


2 Answers


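The error traceback shows what is going wrong: self lists an attribute _repr_html_ in its __dir__, but trying to access it with getattr throws that AttributeError, as the source linked in @maxskoryk's answer shows.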
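The mechanism can be reproduced without sklearn. The following is a minimal sketch, not sklearn's actual code: Stub is a hypothetical stand-in class whose property raises on access, which is assumed to mimic what the traceback shows for _repr_html_.

```python
class Stub:
    """Hypothetical stand-in for an estimator; not part of sklearn."""

    @property
    def _repr_html_(self):
        # The name exists on the class, so dir() reports it,
        # but reading it from an instance raises AttributeError.
        raise AttributeError("not available in this configuration")


stub = Stub()

# dir() only lists names; it does not evaluate the property.
listed = "_repr_html_" in dir(stub)

# Two-argument getattr evaluates the property and lets the error propagate.
try:
    getattr(stub, "_repr_html_")
    raised = False
except AttributeError:
    raised = True

# Three-argument getattr catches AttributeError and returns the default.
fallback = getattr(stub, "_repr_html_", None)

print(listed, raised, fallback)  # True True None
```

Because the failure is an AttributeError, supplying a default to getattr is enough to step over the attribute.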

One workaround is to provide a default value in the getattr call:

   def transform(self):
       functions = sorted(
           #get a list of extractor functions with attribute order
           [
               getattr(self, field, None) for field in dir(self)
               if hasattr(getattr(self, field, None), "order")
           ],
           #sort the feature extractor functions by their order
           key = (lambda field: field.order),
       )
       ...

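You could also restrict the check to attributes that do not start with an underscore, or limit which attributes are inspected in any other sensible way.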
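The underscore filter might look roughly like the sketch below. It runs without sklearn: StubBase and Extractor are hypothetical names, and StubBase only imitates the one failing attribute. Whether every public attribute of a real estimator is safe to read is an assumption here, so combining the filter with a getattr default may still be worthwhile.

```python
def feature_order(order):
    """Decorator from the question: tag a method with its execution order."""
    def order_assignment(to_func):
        to_func.order = order
        return to_func
    return order_assignment


class StubBase:
    """Hypothetical stand-in for BaseEstimator; one private attribute raises."""

    @property
    def _repr_html_(self):
        raise AttributeError("not available in this configuration")


class Extractor(StubBase):
    def __init__(self, raw_data=None):
        # Copy the input so a shared default is never mutated.
        self.raw_data = dict(raw_data or {})

    @feature_order(1)
    def foo_plus_one(self):
        return self.raw_data["foo"] + 1

    @feature_order(2)
    def foo_plus_one_plus_one(self):
        return self.raw_data["foo_plus_one"] + 1

    def transform(self, X=None):
        # Skip every name that starts with an underscore, so the
        # failing _repr_html_ property is never evaluated.
        candidates = (
            getattr(self, name)
            for name in dir(self)
            if not name.startswith("_")
        )
        # Keep only the decorated methods and run them in order.
        functions = sorted(
            (f for f in candidates if hasattr(f, "order")),
            key=lambda f: f.order,
        )
        for func in functions:
            self.raw_data[func.__name__] = func()
        return self.raw_data


result = Extractor({"foo": 1, "bar": 2}).transform()
print(result)
# {'foo': 1, 'bar': 2, 'foo_plus_one': 2, 'foo_plus_one_plus_one': 3}
```

Filtering by name happens before any attribute is read, which is why the property never gets a chance to raise.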

answered 2022-02-08T19:17:38.180

Try this before running your code:

from sklearn import set_config
set_config(display='diagram')

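This happens because the BaseEstimator class has a _repr_html_ attribute that depends on display being set to 'diagram'. I assume the attribute is evaluated at some point and throws the error.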

answered 2022-02-08T11:07:46.530