几个选项:
1)如果您可接受的值是可枚举的,请使用以下Enum
类型:
from enum import Enum
from dataclasses import dataclass
class SomeValue(Enum):
foo = 1
bar = 2
@dataclass
class Config:
# possible values are 'foo' and 'bar'
some_value: SomeValue = SomeValue.foo
如果不需要花哨的逻辑来验证some_value
,这是我推荐的解决方案。
2) 如果您使用的是 yaml 文件,您可以使用 OmegaConf 注册自定义解析器:
# my_python_file.py
from omegaconf import OmegaConf
def check_some_value(value: str) -> str:
assert value in ("foo", "bar")
return value
OmegaConf.register_new_resolver("check_foo_bar", check_some_value)
@hydra.main(...)
...
if __name__ == "__main__":
main()
# my_yaml_file.yaml
some_value: ${check_foo_bar:foo}
当您cfg.some_value
在 python 代码中访问时,AssertionError
如果值与函数不一致,则会引发an check_some_value
。
3)配置组合完成后,您可以调用OmegaConf.to_object
创建数据类的实例。这意味着数据类的__post_init__
函数将被调用。
import hydra
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf
@dataclass
class Config:
# possible values are 'foo' and 'bar'
some_value: str = "foo"
def __post_init__(self) -> None:
assert self.some_value in ("foo", "bar")
@hydra.main(config_path="configs", config_name="config")
def main(dict_cfg: DictConfg):
cfg: Config = OmegaConf.to_object(dict_cfg)
print(cfg)
if __name__ == "__main__":
main()