7

假设我已经训练了一个 Tensorflow Estimator:

estimator = tf.contrib.learn.Estimator(
  model_fn=model_fn,
  model_dir=MODEL_DIR,
  config=some_config)

我将它与一些火车数据相匹配:

estimator.fit(input_fn=input_fn_train, steps=None)

这个想法是模型适合我的 MODEL_DIR。该文件夹包含一个检查点和几个文件.meta.index.

这完美地工作。我想用我的函数做一些预测:

estimator = tf.contrib.Estimator(
  model_fn=model_fn,
  model_dir=MODEL_DIR,
  config=some_config)

predictions = estimator.predict(input_fn=input_fn_test)

我的解决方案完美运行,但有一个很大的缺点:您需要知道 model_fn,这是我在 Python 中定义的模型。但是如果我通过在我的 Python 代码中添加一个密集层来更改模型,那么这个模型对于 MODEL_DIR 中保存的数据是不正确的,从而导致不正确的结果:

NotFoundError (see above for traceback): Key xxxx/dense/kernel not found in checkpoint

我该如何应对?如何加载我的模型/估计器,以便我可以对一些新数据进行预测?如何从 MODEL_DIR 加载 model_fn 或估算器?

4

1 回答 1

1

避免糟糕的修复

仅当模型和检查点兼容时,才能从检查点恢复模型的状态。例如,假设您训练了一个DNNClassifier包含两个隐藏层的 Estimator,每个隐藏层有 10 个节点:

classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris')

classifier.train(
    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
        steps=200)

在训练之后(因此,在 中创建检查点之后models/iris),假设您将每个隐藏层中的神经元数量从 10 个更改为 20 个,然后尝试重新训练模型:

classifier2 = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[20, 20],  # Change the number of neurons in the model.
    n_classes=3,
    model_dir='models/iris')

classifier.train(
    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
        steps=200)

由于检查点中的状态与 中描述的模型不兼容,因此classifier2重新训练失败并出现以下错误:

...
InvalidArgumentError (see above for traceback): tensor_name =
dnn/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does not match the shape stored in checkpoint: [20]

要运行您训练和比较模型的略微不同版本的实验,请保存创建每个 的代码的副本model_dir,可能通过为每个版本创建单独的 git 分支。这种分离将使您的检查点保持可恢复性。

从 tensorflow 检查点文档复制。

https://www.tensorflow.org/get_started/checkpoints

希望可以帮助你。

于 2018-06-11T07:45:49.130 回答