3

我想实现自动编码器(确切地说是堆叠卷积自动编码器)

在这里我想先对每一层进行预训练,然后再进行微调

所以我为每一层的权重创建了变量

前任。第一层的 W_1 = tf.Variable(initial_value, name,trainable=True etc)

我预训练了第一层的 W_1

然后我想预训练第二层的权重(W_2)

这里我应该使用 W_1 来计算第二层的输入。

但是 W_1 是可训练的,所以如果我直接使用 W_1,那么 tensorflow 可以一起训练 W_1。

所以我应该创建保持 W_1 值但不可训练的 W_1_out

老实说,我试图修改这个网站的代码

https://github.com/cmgreen210/TensorFlowDeepAutoencoder/blob/master/code/ae/autoencoder.py

在第 102 行,它通过以下代码创建变量

self[name_w + "_fixed"] = tf.Variable(tf.identity(self[name_w]),
                                            name=name_w + "_fixed",
                                            trainable=False)

但是它调用错误,因为它使用未初始化的值

我应该如何复制变量但使其不可训练来预训练下一层?

4

1 回答 1

3

不确定是否仍然相关,但无论如何我都会尝试。

一般来说,我在这种情况下会做以下事情:

  • 根据您正在构建的模型填充(默认)图,例如,对于第一个训练步骤,只需创建W1您提到的第一个卷积层。当您训练第一层时,您可以在训练完成后存储保存的模型,然后重新加载它并添加第二层所需的操作W2W1或者,您可以直接在代码中再次从头开始构建整个图形,然后为W2.

  • 如果您使用的是 Tensorflow 提供的恢复机制,您将拥有权重W1已经是预训练的优势。如果您不使用恢复机制,您将不得不W1手动设置权重,例如通过执行下面的代码段中显示的操作。

  • 然后,当您设置训练操作时,您可以将变量列表传递var_list给优化器,该列表明确告诉优化器更新哪些参数以最小化损失。如果将其设置为None(默认值),它只会使用它可以找到的tf.trainable_variables()内容,而这些内容又是所有tf.Variables可训练内容的集合。也可以检查这个答案,它基本上说的是同样的事情。
  • 使用var_list参数时,图形集合会派上用场。例如,您可以为要训练的每一层创建一个单独的图形集合。该集合将包含每一层的可训练变量,然后您可以很容易地检索所需的集合并将其作为var_list参数传递(参见下面的示例和/或上述链接文档中的注释)。

如何覆盖变量的值: name是要覆盖的变量的名称,value是适当大小和类型的数组,sess是会话:

variable = tf.get_default_graph().get_tensor_by_name(name)
sess.run(tf.assign(variable, value))

请注意,最后name需要一个额外的,所以例如,如果你的层的:0权重在示例中被称为应该是。'weights1'name'weights1:0'

将张量添加到自定义集合:使用以下几行:

tf.add_to_collection('layer1_tensors', weights1)
tf.add_to_collection('layer1_tensors', some_other_trainable_variable)

请注意,第一行创建了集合,因为它还不存在,第二行将给定的张量添加到现有集合中。

如何使用自定义集合:现在您可以执行以下操作:

# loss = some tensorflow op computing the loss
var_list = tf.get_collection_ref('layer1_tensors')
optim = tf.train.AdamOptimizer().minimize(loss=loss, var_list=var_list)

您也可以使用tf.get_collection('layer_tensors')which 将返回集合的副本。

当然,如果你不想做任何这些,你可以trainable=False在为你不想被训练的所有变量创建图表时使用,正如你在问题中暗示的那样。但是,我不太喜欢该选项,因为它要求您将布尔值传递给填充图形的函数,这很容易被忽略,因此容易出错。此外,即使您决定这样做,您仍然必须手动恢复不可训练的变量。

于 2017-02-06T14:04:31.770 回答