1

我尝试将我的应用程序用于hydra框架。我使用结构化配置模式,我想限制某些字段的可能值。有没有办法做到这一点?

这是我的代码:

我的应用程序.py:

import hydra


@dataclass
class Config:
    # possible values are 'foo' and 'bar'
    some_value: str = "foo"


@hydra.main(config_path="configs", config_name="config")
def main(cfg: Config):
    print(cfg)


if __name__ == "__main__":
    main()

配置/config.yaml:

# value is incorrect.
# I need hydra to throw an exception in this case
some_value: "barrr"
4

1 回答 1

1

几个选项:

1)如果您可接受的值是可枚举的,请使用以下Enum类型:

from enum import Enum
from dataclasses import dataclass

class SomeValue(Enum):
    foo = 1
    bar = 2

@dataclass
class Config:
    # possible values are 'foo' and 'bar'
    some_value: SomeValue = SomeValue.foo

如果不需要花哨的逻辑来验证some_value,这是我推荐的解决方案。

2) 如果您使用的是 yaml 文件,您可以使用 OmegaConf 注册自定义解析器

# my_python_file.py
from omegaconf import OmegaConf

def check_some_value(value: str) -> str:
    assert value in ("foo", "bar")
    return value

OmegaConf.register_new_resolver("check_foo_bar", check_some_value)

@hydra.main(...)
...

if __name__ == "__main__":
    main()
# my_yaml_file.yaml
some_value: ${check_foo_bar:foo}

当您cfg.some_value在 python 代码中访问时,AssertionError如果值与函数不一致,则会引发an check_some_value

3)配置组合完成后,您可以调用OmegaConf.to_object创建数据类的实例。这意味着数据类的__post_init__函数将被调用。

import hydra
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf

@dataclass
class Config:
    # possible values are 'foo' and 'bar'
    some_value: str = "foo"

    def __post_init__(self) -> None:
        assert self.some_value in ("foo", "bar")

@hydra.main(config_path="configs", config_name="config")
def main(dict_cfg: DictConfg):
    cfg: Config = OmegaConf.to_object(dict_cfg)
    print(cfg)

if __name__ == "__main__":
    main()
于 2021-10-07T05:00:35.730 回答