2

我一直无法弄清楚如何通过新的 TF Estimator API使用迁移学习/最后一层再训练。

Estimator需要一个model_fn包含网络架构以及培训和评估操作的文件,如文档所定义。一个model_fn使用 CNN 架构的例子在这里

如果我想重新训练最后一层,例如 inception 架构,我不确定是否需要在 this 中指定整个模型model_fn,然后加载预训练的权重,或者是否有办法使用在“传统”方法中保存的图表(示例here)。

这已作为一个问题提出,但仍然是开放的,我不清楚答案。

4

1 回答 1

2

可以在模型定义期间加载元图并使用 SessionRunHook 从 ckpt 文件加载权重。

def model(features, labels, mode, params):
    # Create the graph here

    return tf.estimator.EstimatorSpec(mode, 
            predictions,
            loss,
            train_op,
            training_hooks=[RestoreHook()])

SessionRunHook 可以是:

class RestoreHook(tf.train.SessionRunHook):

    def after_create_session(self, session, coord=None):
        if session.run(tf.train.get_or_create_global_step()) == 0:
            # load weights here

这样,权重在第一步加载并在模型检查点的训练过程中保存。

于 2018-01-11T16:13:09.373 回答