
I am fine-tuning my model from a pre-trained model using TF-Slim. When calling create_train_op, I noticed it has a parameter variables_to_train. In some tutorials it is used like this:

   import tensorflow as tf
   import tensorflow.contrib.slim as slim

   trainable = tf.trainable_variables()
   train_op  = slim.learning.create_train_op(
        total_loss,                    # loss to minimize (first positional arg)
        opt,                           # optimizer
        global_step=global_step,
        variables_to_train=trainable,
        summarize_gradients=True)

But in the official TF-Slim examples, it is not used:

   train_op = slim.learning.create_train_op(
        total_loss,
        opt,
        global_step=global_step,
        summarize_gradients=True)

So, what is the difference between using and not using variables_to_train?


1 Answer


Your two examples both do the same thing: they train all trainable variables in your graph. With the parameter variables_to_train you can restrict which variables are updated during training.

A typical use case is when you have pre-trained components, such as word embeddings, that you don't want to update while training the rest of the model. With

    train_vars = [v for v in tf.trainable_variables() if "embeddings" not in v.name]
    train_op   = slim.learning.create_train_op(
        total_loss,
        opt,
        global_step=global_step,
        variables_to_train=train_vars,
        summarize_gradients=True)

you exclude from training every variable whose name contains "embeddings". If you simply want to train all variables, you don't need to define train_vars at all: create the train op without the variables_to_train parameter, and it defaults to tf.trainable_variables().
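The name filter itself is plain Python, so it can be sketched without TensorFlow. Below is a minimal illustration with strings standing in for tf.Variable objects; the variable names are hypothetical, chosen only to mimic typical TensorFlow naming:

```python
# Hypothetical variable names, standing in for the .name attributes
# of the tf.Variable objects returned by tf.trainable_variables().
all_trainable = [
    "embeddings/weights:0",
    "dense/kernel:0",
    "dense/bias:0",
]

# Keep only the variables whose name does not contain "embeddings";
# these are the ones that would be passed as variables_to_train,
# so the embedding weights receive no gradient updates.
train_vars = [name for name in all_trainable if "embeddings" not in name]
print(train_vars)  # ['dense/kernel:0', 'dense/bias:0']
```

The same pattern works with any substring, e.g. filtering on a scope prefix like "resnet_v1_50/" to freeze an entire pre-trained backbone.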

Answered 2018-07-21T09:30:38.837