0

在使用实例化时,有没有办法为目标定义一个 2 的幂的参数?例如:

from sklearn.feature_extraction import HashingVectorizer

vec = HashingVectorizer(n_features=2**18)
vec.transform(["a quick fox"])
<1x262144 sparse matrix of type '<class 'numpy.float64'>'
        with 2 stored elements in Compressed Sparse Row format>

正如预期的那样,输出是一个稀疏向量,其形状为 (1, 262144),相当于 2**18。

但是,在配置文件中,您不能使用该值2**18,因为它是作为字符串传入的。

配置.yaml

vec:
  _target_: sklearn.feature_extraction.text.HashingVectorizer
  n_features: 2**18

测试.py

import hydra
import hydra.utils as hu


@hydra.main(config_path='conf', config_name='config')
def main(cfg):
    vec = hu.instantiate(cfg.vec)
    vec.transform(['Erroneous Monk'])


if __name__ == "__main__":
    main()

运行此示例,您将获得以下信息:

python test.py
...
TypeError: n_features must be integral, got '2**18' (<class 'str'>).

有没有办法通知 hydra 该值不应被视为字符串?

4

1 回答 1

1

OmegaConf(底层配置库)目前不支持算术表达式。您可以使用自定义解析器来实现某些东西。例如,您可以通过名称 pow 注册一个自定义解析器,它将在两个输入上调用 Python 幂函数。

import hydra
import hydra.utils as hu
from omegaconf import OmegaConf

# register the resolver before you access the config field.
OmegaConf.register_new_resolver("pow", lambda x,y: x**y)

@hydra.main(config_path='conf', config_name='config')
def main(cfg):
    vec = hu.instantiate(cfg.vec)
    vec.transform(['Erroneous Monk'])


if __name__ == "__main__":
    main()

您的配置可以定义为:

vec:
  _target_: sklearn.feature_extraction.text.HashingVectorizer 
  n_features: ${pow:2,18}
于 2021-10-04T21:07:43.633 回答