2

我有 2 个子配置和一个具有这些子配置的主(?)配置。我设计了如下配置:

from dataclasses import dataclass, field

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, DictConfig

from typing import Any, List

@dataclass
class DBConfig:
    """Base schema for database connection parameters."""
    host: str = "localhost"
    # driver and port carry no concrete default here (MISSING); the
    # subclasses in this script supply them.
    driver: str = MISSING
    port: int = MISSING


@dataclass
class MySQLConfig(DBConfig):
    """DBConfig variant that fills in the MySQL driver name and port."""
    driver: str = "mysql"
    port: int = 3306


@dataclass
class PostGreSQLConfig(DBConfig):
    """DBConfig variant for PostgreSQL; adds a `timeout` field."""
    driver: str = "postgresql"
    port: int = 5432
    # Extra field that the DBConfig base class does not declare.
    timeout: int = 10


@dataclass
class ConnectionConfig:
    """Pairs a target class path with the DBConfig parameters for it."""
    target: str = "app.my_class.MyClass"
    # Left as MISSING; the intent is for the defaults list below to select
    # one of the registered "params" options.
    params: DBConfig = MISSING
    # NOTE(review): this defaults list lives on a nested (non-primary)
    # config. In the output shown for this script it is printed back as
    # plain data and `params` stays `???`, i.e. it is not used for
    # composition.
    defaults: List[Any] = field(
        default_factory=lambda: [
            {
                "params": "mysql",      # I'd like to set mysql as a default
            }
        ]
    )



@dataclass
class AConfig:
    """Base schema for the second parameter family; declares only `name`."""
    name: str = "foo"


@dataclass
class BConfig(AConfig):
    """AConfig variant that adds `age` with a default of 10."""
    age: int = 10


@dataclass
class CConfig(AConfig):
    """AConfig variant that adds `age` with a default of 20."""
    age: int = 20


@dataclass
class SomeOtherConfig:
    """Pairs a target class path with the AConfig parameters for it."""
    target: str = "app.my_class.MyClass2"
    # Left as MISSING; the intent is for the defaults list below to select
    # one of the registered "params" options.
    params: AConfig = MISSING
    # NOTE(review): this defaults list lives on a nested (non-primary)
    # config. In the output shown for this script it is printed back as
    # plain data and `params` stays `???`, i.e. it is not used for
    # composition.
    defaults: List[Any] = field(
        default_factory=lambda: [
            {
                "params": "bconfig",   # I'd like to set bconfig as a default
            }
        ]
    )



@dataclass
class Config:
    """Root schema that groups the two sub-configurations."""
    # Dataclass instances are unhashable by default, and Python >= 3.11
    # rejects unhashable objects as field defaults (ValueError at class
    # creation). default_factory builds a fresh sub-config per instance.
    db_connection: ConnectionConfig = field(default_factory=ConnectionConfig)
    some_other: SomeOtherConfig = field(default_factory=SomeOtherConfig)


@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    """Print the composed configuration as YAML.

    Args:
        cfg: The configuration object Hydra passes in for the "config"
            root.
    """
    # DictConfig.pretty() is deprecated in OmegaConf 2.0 and removed in
    # 2.1; OmegaConf.to_yaml() is its replacement. Imported locally so the
    # snippet stays self-contained.
    from omegaconf import OmegaConf

    print(OmegaConf.to_yaml(cfg))
    # connection = hydra.utils.instantiate(cfg)
    # print(connection)


if __name__ == "__main__":
    # Register the root schema first, then every option of the "params"
    # group in a fixed order, and finally hand control to Hydra.
    config_store = ConfigStore.instance()
    config_store.store(name="config", node=Config)

    params_options = (
        ("mysql", MySQLConfig),
        ("postgresql", PostGreSQLConfig),
        ("bconfig", BConfig),
        ("cconfig", CConfig),
    )
    for option_name, schema in params_options:
        config_store.store(group="params", name=option_name, node=schema)

    my_app()

在不带任何选项运行程序时,我期望得到的输出是:

db_connection:
    target: app.my_class.MyClass
    params:   
        host: localhost
        driver: mysql
        port: 3306   

some_other:
    target: app.my_class.MyClass2
    params:
        name: "foo"
        age: 10

但实际得到的结果是:

db_connection:
    target: app.my_class.MyClass
    params: ???
    defaults:
    - params: mysql
some_other:
    target: app.my_class.MyClass2
    params: ???
    defaults:
    - params: bconfig
4

1 回答 1

3

首先,截至 Hydra 1.0,默认列表(defaults list)仅在主配置(primary config)中受支持。下面给出两个版本:第一个版本尽可能少地改动您的示例,第二个版本则做了一些清理。

示例 1:

from dataclasses import dataclass, field

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, DictConfig

from typing import Any, List


@dataclass
class DBConfig:
    """Base schema for database connection parameters."""
    host: str = "localhost"
    # driver and port carry no concrete default here (MISSING); the
    # subclasses in this script supply them.
    driver: str = MISSING
    port: int = MISSING


@dataclass
class MySQLConfig(DBConfig):
    """DBConfig variant that fills in the MySQL driver name and port."""
    driver: str = "mysql"
    port: int = 3306


@dataclass
class PostGreSQLConfig(DBConfig):
    """DBConfig variant for PostgreSQL; adds a `timeout` field."""
    driver: str = "postgresql"
    port: int = 5432
    # Extra field that the DBConfig base class does not declare.
    timeout: int = 10


@dataclass
class ConnectionConfig:
    """Pairs a target class path with the DBConfig parameters for it."""
    target: str = "app.my_class.MyClass"
    # Left as MISSING here; this script registers the candidate values
    # under the "db_connection/params" group.
    params: DBConfig = MISSING


@dataclass
class AConfig:
    """Base schema for the second parameter family; declares only `name`."""
    name: str = "foo"


@dataclass
class BConfig(AConfig):
    """AConfig variant that adds `age` with a default of 10."""
    age: int = 10


@dataclass
class CConfig(AConfig):
    """AConfig variant that adds `age` with a default of 20."""
    age: int = 20


@dataclass
class SomeOtherConfig:
    """Pairs a target class path with the AConfig parameters for it."""
    target: str = "app.my_class.MyClass2"
    # Left as MISSING here; this script registers the candidate values
    # under the "some_other/params" group.
    params: AConfig = MISSING


@dataclass
class Config:
    """Primary config: both sub-configs plus the defaults list.

    The defaults list is declared here, on the primary config, and names
    the option to use for each nested `params` group.
    """
    # Dataclass instances are unhashable by default, and Python >= 3.11
    # rejects unhashable objects as field defaults (ValueError at class
    # creation). default_factory builds a fresh sub-config per instance.
    db_connection: ConnectionConfig = field(default_factory=ConnectionConfig)
    some_other: SomeOtherConfig = field(default_factory=SomeOtherConfig)
    defaults: List[Any] = field(
        default_factory=lambda: [
            {"db_connection/params": "mysql"},
            {"some_other/params": "bconfig"},
        ]
    )


@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    """Print the composed configuration as YAML.

    Args:
        cfg: The configuration object Hydra passes in for the "config"
            root.
    """
    # DictConfig.pretty() is deprecated in OmegaConf 2.0 and removed in
    # 2.1; OmegaConf.to_yaml() is its replacement. Imported locally so the
    # snippet stays self-contained.
    from omegaconf import OmegaConf

    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    # Register the primary config first, then each group option in a fixed
    # order, and finally hand control to Hydra.
    config_store = ConfigStore.instance()
    config_store.store(name="config", node=Config)

    group_options = (
        ("db_connection/params", "mysql", MySQLConfig),
        ("db_connection/params", "postgresql", PostGreSQLConfig),
        ("some_other/params", "bconfig", BConfig),
        ("some_other/params", "cconfig", CConfig),
    )
    for group_name, option_name, schema in group_options:
        config_store.store(group=group_name, name=option_name, node=schema)

    my_app()

示例 2:

from dataclasses import dataclass, field

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, DictConfig
from hydra.types import ObjectConf
from typing import Any, List


@dataclass
class DBConfig:
    """Base schema for database connection parameters."""
    host: str = "localhost"
    # driver and port carry no concrete default here (MISSING); the
    # subclasses in this script supply them.
    driver: str = MISSING
    port: int = MISSING


@dataclass
class MySQLConfig(DBConfig):
    """DBConfig variant that fills in the MySQL driver name and port."""
    driver: str = "mysql"
    port: int = 3306


@dataclass
class PostGreSQLConfig(DBConfig):
    """DBConfig variant for PostgreSQL; adds a `timeout` field."""
    driver: str = "postgresql"
    port: int = 5432
    # Extra field that the DBConfig base class does not declare.
    timeout: int = 10


@dataclass
class AConfig:
    """Base schema for the second parameter family; declares only `name`."""
    name: str = "foo"


@dataclass
class BConfig(AConfig):
    """AConfig variant that adds `age` with a default of 10."""
    age: int = 10


@dataclass
class CConfig(AConfig):
    """AConfig variant that adds `age` with a default of 20."""
    age: int = 20


# Default option for each config group: "mysql" for db_connection and
# "bconfig" for some_other.
defaults = [{"db_connection": "mysql"}, {"some_other": "bconfig"}]


@dataclass
class Config:
    """Primary config: two ObjectConf slots plus the defaults list."""
    # Both slots are MISSING here; this script registers their candidate
    # values under the "db_connection" and "some_other" groups.
    db_connection: ObjectConf = MISSING
    some_other: ObjectConf = MISSING
    # Hand every instance its own list of fresh dicts. Returning the
    # module-level `defaults` object itself would make all instances share
    # (and be able to mutate) one list.
    defaults: List[Any] = field(
        default_factory=lambda: [dict(entry) for entry in defaults]
    )


# Register the primary config, then one ObjectConf (target + params) per
# selectable option in each group. `params` is passed as an instance
# throughout so that all four registrations follow the same pattern.
cs = ConfigStore.instance()
cs.store(name="config", node=Config)
cs.store(
    group="db_connection",
    name="mysql",
    node=ObjectConf(target="MySQL", params=MySQLConfig()),
)
cs.store(
    group="db_connection",
    name="postgresql",
    node=ObjectConf(target="PostgreSQL", params=PostGreSQLConfig()),
)
cs.store(
    group="some_other",
    name="bconfig",
    node=ObjectConf(target="ClassB", params=BConfig()),
)
cs.store(
    group="some_other",
    name="cconfig",
    # The "cconfig" option carries CConfig (name + age=20), mirroring how
    # "bconfig" carries BConfig.
    node=ObjectConf(target="ClassC", params=CConfig()),
)


@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    """Print the composed configuration as YAML.

    Args:
        cfg: The configuration object Hydra passes in for the "config"
            root.
    """
    # DictConfig.pretty() is deprecated in OmegaConf 2.0 and removed in
    # 2.1; OmegaConf.to_yaml() is its replacement. Imported locally so the
    # snippet stays self-contained.
    from omegaconf import OmegaConf

    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    # The config store is populated at module level above, so the entry
    # point only needs to start the Hydra-decorated app.
    my_app()

于 2020-07-29T07:12:26.090 回答