我正在使用 Hydra 来训练机器学习模型。它非常适合执行复杂的命令,例如python train.py data=MNIST batch_size=64 loss=l2
. 但是,如果我想使用相同的参数运行经过训练的模型,我必须执行类似python reconstruct.py --config_file path_to_previous_job/.hydra/config.yaml
. 然后我使用argparse
加载之前的 yaml 并使用 compose API 来初始化 Hydra 环境。训练模型的路径是从 Hydra.yaml
文件的路径推断出来的。如果我想修改其中一个参数,我必须添加其他argparse
参数并运行类似python reconstruct.py --config_file path_to_previous_job/.hydra/config.yaml --batch_size 128
. 然后,该代码使用命令行上指定的参数手动覆盖任何 Hydra 参数。
这样做的正确方法是什么?
我当前的代码如下所示:
train.py
:
import hydra
@hydra.main(config_name="config", config_path="conf")
def main(cfg):
# [training code using cfg.data, cfg.batch_size, cfg.loss etc.]
# [code outputs model checkpoint to job folder generated by Hydra]
main()
reconstruct.py
:
import argparse
import os
from hydra.experimental import initialize, compose
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('hydra_config')
parser.add_argument('--batch_size', type=int)
# [other flags and parameters I may need to override]
args = parser.parse_args()
# Create the Hydra environment.
initialize()
cfg = compose(config_name=args.hydra_config)
# Since checkpoints are stored next to the .hydra, we manually generate the path.
checkpoint_dir = os.path.dirname(os.path.dirname(args.hydra_config))
# Manually override any parameters which can be changed on the command line.
batch_size = args.batch_size if args.batch_size else cfg.data.batch_size
# [code which uses checkpoint_dir to load the model]
# [code which uses both batch_size and params in cfg to set up the data etc.]
这是我第一次发帖,所以如果我需要澄清什么,请告诉我。