假设我已经训练了一个 Tensorflow Estimator:
estimator = tf.contrib.learn.Estimator(
model_fn=model_fn,
model_dir=MODEL_DIR,
config=some_config)
我将它与一些火车数据相匹配:
estimator.fit(input_fn=input_fn_train, steps=None)
这个想法是模型适合我的 MODEL_DIR。该文件夹包含一个检查点和几个文件.meta
和.index
.
这完美地工作。我想用我的函数做一些预测:
estimator = tf.contrib.Estimator(
model_fn=model_fn,
model_dir=MODEL_DIR,
config=some_config)
predictions = estimator.predict(input_fn=input_fn_test)
我的解决方案完美运行,但有一个很大的缺点:您需要知道 model_fn,这是我在 Python 中定义的模型。但是如果我通过在我的 Python 代码中添加一个密集层来更改模型,那么这个模型对于 MODEL_DIR 中保存的数据是不正确的,从而导致不正确的结果:
NotFoundError (see above for traceback): Key xxxx/dense/kernel not found in checkpoint
我该如何应对?如何加载我的模型/估计器,以便我可以对一些新数据进行预测?如何从 MODEL_DIR 加载 model_fn 或估算器?