38

我在范围内创建了一个可训练变量。后来我进入同一个作用域,把作用域设置为reuse_variables,用来get_variable检索同一个变量。但是,我无法将变量的可训练属性设置为False. 我的get_variable线路是这样的:

weight_var = tf.get_variable('weights', trainable = False)

但变量'weights'仍在 的输出中tf.trainable_variables

我可以使用设置共享变量的trainable标志吗?Falseget_variable

我想这样做的原因是我试图在我的模型中重用从 VGG 网络预训练的低级过滤器,我想像以前一样构建图表,检索权重变量,并分配 VGG 过滤器值到权重变量,然后在接下来的训练步骤中保持它们不变。

4

4 回答 4

31

查看文档和代码后,我无法找到从TRAINABLE_VARIABLES.

这是发生的事情:

  • 第一次tf.get_variable('weights', trainable=True)调用时,变量被添加到列表中TRAINABLE_VARIABLES
  • 第二次调用tf.get_variable('weights', trainable=False),您得到相同的变量,但参数trainable=False无效,因为该变量已经存在于列表中TRAINABLE_VARIABLES(并且无法从那里删除它

第一个解决方案

调用minimize优化器的方法时(请参阅doc.),您可以将var_list=[...]作为参数传递给您想要优化器的变量。

例如,如果你想冻结除最后两层之外的所有 VGG 层,你可以将最后两层的权重传递给var_list.

第二种解决方案

您可以使用 atf.train.Saver()来保存变量并在以后恢复它们(请参阅本教程)。

  • 首先,您使用所有可训练变量训练整个 VGG 模型。您可以通过调用将它们保存在检查点文件中saver.save(sess, "/path/to/dir/model.ckpt")
  • 然后(在另一个文件中)用不可训练的变量训练第二个版本。您加载以前存储的变量saver.restore(sess, "/path/to/dir/model.ckpt")

或者,您可以决定只保存检查点文件中的一些变量。有关更多信息,请参阅文档

于 2016-05-19T15:19:16.833 回答
13

当您只想训练或优化预训练网络的某些层时,这就是您需要知道的。

TensorFlow 的minimize方法采用一个可选参数var_list,即要通过反向传播调整的变量列表。

如果不指定var_list,则图中的任何 TF 变量都可以由优化器进行调整。当您在 中指定一些变量时var_list,TF 将所有其他变量保持不变。

这是jonbruner和他的合作者使用的脚本示例。

tvars = tf.trainable_variables()
g_vars = [var for var in tvars if 'g_' in var.name]
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)

这会找到他们之前定义的所有变量名中包含“g_”的变量,将它们放入列表中,然后在它们上运行 ADAM 优化器。

您可以在Quora上找到相关答案

于 2018-04-13T13:09:03.450 回答
7

为了从可训练变量列表中删除一个变量,您可以首先通过以下方式访问该集合: trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES) 那里,trainable_collection包含对可训练变量集合的引用。如果你从这个列表中弹出元素,例如trainable_collection.pop(0),你将从可训练变量中删除相应的变量,因此这个变量将不会被训练。

尽管这适用于pop,但我仍在努力寻找一种正确使用remove正确参数的方法,因此我们不依赖于变量的索引。

编辑:假设您有图中变量的名称(您可以通过检查图 protobuf 或使用 Tensorboard 更容易获得),您可以使用它来遍历可训练变量列表,然后删除来自可训练集合的变量。示例:假设我想要带有名称的变量"batch_normalization/gamma:0"并且"batch_normalization/beta:0" 被训练,但它们已经添加到TRAINABLE_VARIABLES集合中。我能做的是:`

#gets a reference to the list containing the trainable variables
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
variables_to_remove = list()
for vari in trainable_collection:
    #uses the attribute 'name' of the variable
    if vari.name=="batch_normalization/gamma:0" or vari.name=="batch_normalization/beta:0":
        variables_to_remove.append(vari)
for rem in variables_to_remove:
    trainable_collection.remove(rem)

` 这将成功地从集合中移除这两个变量,它们将不再被训练。

于 2018-10-23T07:47:25.940 回答
0

您可以使用 tf.get_collection_ref 来获取集合的引用,而不是 tf.get_collection

于 2019-05-17T20:21:56.000 回答