
I'm using Hydra to train machine learning models. It's great for composing complex commands, e.g. python train.py data=MNIST batch_size=64 loss=l2. However, if I want to run the trained model with the same parameters, I have to run something like python reconstruct.py --config_file path_to_previous_job/.hydra/config.yaml. I then use argparse to load the previous yaml and use the Compose API to initialize the Hydra environment. The path to the trained model is inferred from the path of the Hydra yaml file. If I want to modify one of the parameters, I have to add extra argparse arguments and run something like python reconstruct.py --config_file path_to_previous_job/.hydra/config.yaml --batch_size 128. The code then manually overrides any Hydra parameters with the ones specified on the command line.

What is the right way of doing this?

My current code looks something like this:

train.py

import hydra

@hydra.main(config_name="config", config_path="conf")
def main(cfg):
    # [training code using cfg.data, cfg.batch_size, cfg.loss etc.]
    # [code outputs model checkpoint to job folder generated by Hydra]
if __name__ == "__main__":
    main()

reconstruct.py

import argparse
import os
from hydra.experimental import initialize, compose

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('hydra_config')
    parser.add_argument('--batch_size', type=int)
    # [other flags and parameters I may need to override]
    args = parser.parse_args()

    # Create the Hydra environment.
    initialize()
    cfg = compose(config_name=args.hydra_config)

    # Since checkpoints are stored next to the .hydra, we manually generate the path.
    checkpoint_dir = os.path.dirname(os.path.dirname(args.hydra_config))

    # Manually override any parameters which can be changed on the command line.
    batch_size = args.batch_size if args.batch_size is not None else cfg.batch_size

    # [code which uses checkpoint_dir to load the model]
    # [code which uses both batch_size and params in cfg to set up the data etc.]

This is my first post, so please let me know if anything needs clarification.


1 Answer


If you want to load the previous config as is, without changing it, use OmegaConf.load(file_path).

If you want to re-compose the config (which it sounds like you do, since you mention adding things you want to override), I suggest using the Compose API and passing in the parameters from the overrides file in the job output directory (next to the stored config.yaml), concatenated with the overrides of the current run.

This script seems to do the job:

import os
from dataclasses import dataclass
from os.path import join
from typing import Optional

from omegaconf import OmegaConf

import hydra
from hydra import compose
from hydra.core.config_store import ConfigStore
from hydra.core.hydra_config import HydraConfig
from hydra.utils import to_absolute_path

# You can also use a yaml config file instead of this Structured Config
@dataclass
class Config:
    load_checkpoint: Optional[str] = None
    batch_size: int = 16
    loss: str = "l2"


cs = ConfigStore.instance()
cs.store(name="config", node=Config)


@hydra.main(config_path=".", config_name="config")
def my_app(cfg: Config) -> None:

    if cfg.load_checkpoint is not None:
        output_dir = to_absolute_path(cfg.load_checkpoint)
        original_overrides = OmegaConf.load(join(output_dir, ".hydra/overrides.yaml"))
        current_overrides = HydraConfig.get().overrides.task

        hydra_config = OmegaConf.load(join(output_dir, ".hydra/hydra.yaml"))
        # getting the config name from the previous job.
        config_name = hydra_config.hydra.job.config_name
        # concatenating the original overrides with the current overrides
        overrides = original_overrides + current_overrides
        # compose a new config from scratch
        cfg = compose(config_name, overrides=overrides)

    # train
    print("Running in ", os.getcwd())
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    my_app()

~/tmp$ python train.py 
Running in  /home/omry/tmp/outputs/2021-04-19/21-23-13
load_checkpoint: null
batch_size: 16
loss: l2

~/tmp$ python train.py load_checkpoint=/home/omry/tmp/outputs/2021-04-19/21-23-13
Running in  /home/omry/tmp/outputs/2021-04-19/21-23-22
load_checkpoint: /home/omry/tmp/outputs/2021-04-19/21-23-13
batch_size: 16
loss: l2

~/tmp$ python train.py load_checkpoint=/home/omry/tmp/outputs/2021-04-19/21-23-13 batch_size=32
Running in  /home/omry/tmp/outputs/2021-04-19/21-23-28
load_checkpoint: /home/omry/tmp/outputs/2021-04-19/21-23-13
batch_size: 32
loss: l2

Answered 2021-04-20T04:26:41.633