
I've noticed that the new Estimator API automatically saves checkpoints during training and automatically restarts from the last checkpoint when training is interrupted. Unfortunately, it seems to keep only the last 5 checkpoints.

Do you know how to control the number of checkpoints that are kept during the training?


2 Answers


TensorFlow's tf.estimator.Estimator takes an optional config argument, which can be a tf.estimator.RunConfig object that configures run-time settings. You can do it like this:

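import tensorflow as tf

# Keep up to 25 checkpoints instead of the default 5
run_config = tf.estimator.RunConfig()
run_config = run_config.replace(keep_checkpoint_max=25)
# (Equivalently: run_config = tf.estimator.RunConfig(keep_checkpoint_max=25))

# Build your estimator; model_fn and job_dir are your own model function and output directory
estimator = tf.estimator.Estimator(model_fn,
                                   model_dir=job_dir,
                                   config=run_config,
                                   params=None)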

The config parameter is available in all classes that extend estimator.Estimator (DNNClassifier, DNNLinearCombinedClassifier, LinearClassifier, etc.).
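For reference, the RunConfig documentation describes keep_checkpoint_max as the number of most recent checkpoint files to keep: it defaults to 5, and None or 0 keeps every checkpoint. The eviction behaviour can be sketched in plain Python as below. KeepLastN is a hypothetical helper written only to illustrate "keep the last N" semantics; it is not TensorFlow's actual implementation.

```python
import tempfile
from collections import deque
from pathlib import Path


class KeepLastN:
    """Hypothetical helper: writes checkpoint files and deletes the oldest
    ones once more than max_to_keep exist (illustration only)."""

    def __init__(self, directory: Path, max_to_keep: int = 5):
        if max_to_keep < 1:
            raise ValueError("max_to_keep must be >= 1")
        self.directory = directory
        self.max_to_keep = max_to_keep
        self._kept = deque()  # oldest checkpoint path on the left

    def save(self, step: int, payload: bytes) -> Path:
        path = self.directory / f"model.ckpt-{step}"
        path.write_bytes(payload)
        self._kept.append(path)
        # Evict the oldest checkpoints until at most max_to_keep remain
        while len(self._kept) > self.max_to_keep:
            self._kept.popleft().unlink()
        return path

    @property
    def kept(self):
        return list(self._kept)


with tempfile.TemporaryDirectory() as tmp:
    saver = KeepLastN(Path(tmp), max_to_keep=5)
    for step in range(0, 1000, 100):  # 10 saves: steps 0, 100, ..., 900
        saver.save(step, b"weights")
    kept_names = [p.name for p in saver.kept]
    on_disk = sorted(p.name for p in Path(tmp).iterdir())

print(kept_names)
# ['model.ckpt-500', 'model.ckpt-600', 'model.ckpt-700', 'model.ckpt-800', 'model.ckpt-900']
```

Raising max_to_keep (like raising keep_checkpoint_max) simply delays eviction; only the newest N files survive.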

answered 2017-12-30T07:59:47.633

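As a side note, I'd like to add that in TensorFlow 2 things are a bit simpler. To keep a given number of checkpoint files, you can modify the model_main_tf2.py source code (from the TensorFlow Object Detection API). First, add and define an integer flag: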

from absl import flags  # already imported at the top of model_main_tf2.py

# Keep the last 25 checkpoints by default
flags.DEFINE_integer('checkpoint_max_to_keep', 25,
                     'Integer defining how many checkpoint files to keep.')

Then pass this value in the call to model_lib_v2.train_loop:

# Ensure training loop keeps last 25 checkpoints
model_lib_v2.train_loop(...,
                        checkpoint_max_to_keep=FLAGS.checkpoint_max_to_keep,
                        ...)

The ... above stands for the other arguments of model_lib_v2.train_loop.
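Outside the Object Detection API, a custom TF2 training loop can set the same limit directly on tf.train.CheckpointManager through its max_to_keep argument (model_lib_v2 appears to build a CheckpointManager from checkpoint_max_to_keep internally, but treat that detail as an assumption). A minimal sketch, assuming a toy Checkpoint that tracks only a step counter:

```python
import tempfile

import tensorflow as tf

# Toy trackable state; a real loop would also track the model and optimizer
step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)

ckpt_dir = tempfile.mkdtemp()
# Keep only the 3 most recent checkpoints; older ones are deleted on save()
manager = tf.train.CheckpointManager(ckpt, directory=ckpt_dir, max_to_keep=3)

for _ in range(5):
    step.assign_add(1)
    manager.save()

print(len(manager.checkpoints))  # 3

# Resuming: restore the newest checkpoint into a fresh variable
restored_step = tf.Variable(0, dtype=tf.int64)
tf.train.Checkpoint(step=restored_step).restore(manager.latest_checkpoint)
print(int(restored_step.numpy()))  # 5
```

Passing max_to_keep=None keeps every checkpoint instead of deleting the oldest ones.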

answered 2021-09-08T10:03:34.257