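I'm using Hydra to train machine learning models. It works well for complex commands such as python train.py data=MNIST batch_size=64 loss=l2. However, when I want to run the trained model with the same parameters, I have to do something like python reconstruct.py --config_file path_to_previous_job/.hydra/config.yaml. I then use argparse to load the previous yaml and use the compose API to initialize the Hydra environment. The path to the trained model is inferred from the path to Hydra's .yaml file. If I want to modify one of the parameters, I have to add additional argparse parameters and run something like python reconstruct.py --config_file path_to_previous_job/.hydra/config.yaml --batch_size 128. The code then manually overrides any Hydra parameters with those specified on the command line.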




# train.py
import hydra

@hydra.main(config_name="config", config_path="conf")
def main(cfg):
    # [training code using cfg.data, cfg.batch_size, cfg.loss etc.]
    # [code outputs model checkpoint to job folder generated by Hydra]
    pass  # placeholder so the elided body is valid Python

if __name__ == "__main__":
    main()


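# reconstruct.py
import argparse
import os
from hydra.experimental import compose, initialize_config_dir

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Path to the config saved by the training job,
    # e.g. path_to_previous_job/.hydra/config.yaml
    parser.add_argument('--config_file', required=True)
    parser.add_argument('--batch_size', type=int)
    # [other flags and parameters I may need to override]
    args = parser.parse_args()

    # Create the Hydra environment from the saved .hydra directory.
    # initialize_config_dir expects an absolute path.
    hydra_dir = os.path.dirname(os.path.abspath(args.config_file))
    config_name = os.path.splitext(os.path.basename(args.config_file))[0]
    with initialize_config_dir(config_dir=hydra_dir):
        cfg = compose(config_name=config_name)

    # Since checkpoints are stored next to the .hydra, we manually generate the path.
    checkpoint_dir = os.path.dirname(hydra_dir)

    # Manually override any parameters which can be changed on the command line.
    batch_size = args.batch_size if args.batch_size is not None else cfg.data.batch_size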

    # [code which uses checkpoint_dir to load the model]
    # [code which uses both batch_size and params in cfg to set up the data etc.]



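If you want to recompose the config (and it sounds like you do, since you are adding things you want to override), I suggest using the Compose API and passing in the parameters from the overrides file in the job output directory (stored next to config.yaml), concatenated with the parameters of the current run.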


import os
from dataclasses import dataclass
from os.path import join
from typing import Optional

from omegaconf import OmegaConf

import hydra
from hydra import compose
from hydra.core.config_store import ConfigStore
from hydra.core.hydra_config import HydraConfig
from hydra.utils import to_absolute_path

# You can also use a yaml config file instead of this Structured Config
@dataclass
class Config:
    load_checkpoint: Optional[str] = None
    batch_size: int = 16
    loss: str = "l2"

cs = ConfigStore.instance()
cs.store(name="config", node=Config)

@hydra.main(config_path=".", config_name="config")
def my_app(cfg: Config) -> None:

    if cfg.load_checkpoint is not None:
        output_dir = to_absolute_path(cfg.load_checkpoint)
        original_overrides = OmegaConf.load(join(output_dir, ".hydra/overrides.yaml"))
        current_overrides = HydraConfig.get().overrides.task

        hydra_config = OmegaConf.load(join(output_dir, ".hydra/hydra.yaml"))
        # getting the config name from the previous job.
        config_name = hydra_config.hydra.job.config_name
        # concatenating the original overrides with the current overrides
        overrides = original_overrides + current_overrides
        # compose a new config from scratch
        cfg = compose(config_name, overrides=overrides)

    # train
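    print("Running in ", os.getcwd())
    # print the effective config (this is the YAML shown in the runs below)
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    my_app()
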
~/tmp$ python train.py 
Running in  /home/omry/tmp/outputs/2021-04-19/21-23-13
load_checkpoint: null
batch_size: 16
loss: l2

~/tmp$ python train.py load_checkpoint=/home/omry/tmp/outputs/2021-04-19/21-23-13
Running in  /home/omry/tmp/outputs/2021-04-19/21-23-22
load_checkpoint: /home/omry/tmp/outputs/2021-04-19/21-23-13
batch_size: 16
loss: l2

~/tmp$ python train.py load_checkpoint=/home/omry/tmp/outputs/2021-04-19/21-23-13 batch_size=32
Running in  /home/omry/tmp/outputs/2021-04-19/21-23-28
load_checkpoint: /home/omry/tmp/outputs/2021-04-19/21-23-13
batch_size: 32
loss: l2
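
The concatenation step above (original overrides first, current overrides second) is meant to let the current run's values win when both lists set the same key. Below is a minimal pure-Python sketch of that precedence rule. It is not Hydra's override parser: the helper name merge_overrides is hypothetical, it assumes plain key=value strings only (no +key, ~key, or other override syntax), and it assumes that a later entry replaces an earlier one for the same key.

```python
from typing import Dict, List


def merge_overrides(original: List[str], current: List[str]) -> Dict[str, str]:
    """Illustrative helper (hypothetical, not part of Hydra).

    Combines two lists of 'key=value' strings; when a key appears
    more than once, the last occurrence wins.
    """
    merged: Dict[str, str] = {}
    for item in original + current:
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        merged[key] = value
    return merged


if __name__ == "__main__":
    # Overrides as they might be saved by the training job from the question.
    original = ["data=MNIST", "batch_size=64", "loss=l2"]
    # Overrides passed to the current run.
    current = ["batch_size=128"]
    print(merge_overrides(original, current))
```

In the real script the combined list is handed to compose as-is, so Hydra does the parsing and merging; the sketch only shows why the order of concatenation matters.
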
Answered 2021-04-20T04:26:41.633