2

我已经使用 tf.layers.conv2d 层构建了一个自动编码器,并希望分阶段对其进行训练。也就是先训练外层,然后是中间层,然后是内层。我知道使用 tf.nn.conv2d 可以做到这一点,因为权重是使用 tf.get_variable 声明的,但我认为使用 tf.layers.conv2d 也应该可以做到这一点。

如果我输入一个与原始图不同的新变量范围来更改卷积层的输入(即在第 1 阶段跳过内层),我将无法重用权重。如果我不输入新的变量范围,我将无法冻结我不想在此阶段训练的权重。

基本上我正在尝试使用来自 Aurélien Géron 的训练方法https://github.com/ageron/handson-ml/blob/master/15_autoencoders.ipynb

除了我想使用 cnn 而不是密集层。怎么做?

4

2 回答 2

6

无需手动创建变量。这同样有效:

import tensorflow as tf

inputs_1 = tf.placeholder(tf.float32, (None, 512, 512, 3), name='inputs_1')
inputs_2 = tf.placeholder(tf.float32, (None, 512, 512, 3), name='inputs_2')

with tf.variable_scope('conv'):
    out_1 = tf.layers.conv2d(inputs_1, 32, [3, 3], name='conv_1')

with tf.variable_scope('conv', reuse=True):
    out_2 = tf.layers.conv2d(inputs_2, 32, [3, 3], name='conv_1')

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print(tf.trainable_variables())

如果您给出tf.layers.conv2d相同的名称,它将使用相同的权重(假设reuse=True,否则将有 a ValueError)。

在 Tesorflow 2.0 中: tf.layers被 keras 层替换,其中变量通过使用相同的层对象被重用:

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           input_shape=(512, 512, 3)), 
])

@tf.function 
def f1(x):
    return model(x)

@tf.function 
def f2(x):
    return model(x)

两者都f1f2使用具有相同变量的层

于 2018-04-26T17:12:44.393 回答
1

我建议设置它有点不同。我不会使用 tf.layers.conv2d,而是使用对 tf.get_variable() 的调用显式地制作权重,然后将这些权重与对 tf.nn.conv2d() 的调用一起使用。这样,您就不会黑箱化变量创建,并且可以轻松引用它们。这也是准确了解网络中正在发生的事情的好方法,因为您手动为每组权重编写了形状!

示例(未经测试)代码:

inputs = tf.placeholder(tf.float32, (batch_size, 512, 512, 3), name='inputs')
weights = tf.get_variable(name='weights', shape=[5, 5, 3, 16], dtype=tf.float32)

with tf.variable_scope("convs"):
    hidden_layer_1 = tf.nn.conv2d(input=inputs, filter=weights, stride=[1, 1, 1, 1], padding="SAME")
with tf.variable_scope("convs", reuse=True):
    hidden_layer_2 = tf.nn.conv2d(input=hidden_layer_1, filter=weights,stride=[1, 1, 1, 1], padding="SAME"

这会创建卷积权重并将其两次应用于您的输入。我还没有测试过这段代码,所以可能有错误,但它是关于它应该看起来如何。此处引用变量共享,此处引用 tf.nn.conv2d

希望这会有所帮助!我会更彻底,但我不知道你的代码是什么样的。

于 2018-04-25T20:03:01.737 回答