为了从可训练变量列表中删除一个变量,您可以首先通过以下方式访问该集合:
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
那里,trainable_collection
包含对可训练变量集合的引用。如果你从这个列表中弹出元素,例如trainable_collection.pop(0)
,你将从可训练变量中删除相应的变量,因此这个变量将不会被训练。
尽管这适用于pop
,但我仍在努力寻找一种正确使用remove
正确参数的方法,因此我们不依赖于变量的索引。
编辑:假设您有图中变量的名称(您可以通过检查图 protobuf 或使用 Tensorboard 更容易获得),您可以使用它来遍历可训练变量列表,然后删除来自可训练集合的变量。示例:假设我想要带有名称的变量"batch_normalization/gamma:0"
并且"batch_normalization/beta:0"
不被训练,但它们已经添加到TRAINABLE_VARIABLES
集合中。我能做的是:`
#gets a reference to the list containing the trainable variables
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
variables_to_remove = list()
for vari in trainable_collection:
#uses the attribute 'name' of the variable
if vari.name=="batch_normalization/gamma:0" or vari.name=="batch_normalization/beta:0":
variables_to_remove.append(vari)
for rem in variables_to_remove:
trainable_collection.remove(rem)
` 这将成功地从集合中移除这两个变量,它们将不再被训练。