您可以使用参考张量的动态形状,而不是静态形状。
通常,当您使用该conv2d_transpose
操作时,您会“上采样”一层以获得网络中另一个张量的某种形状。
例如,如果您想复制input_tensor
张量的形状,您可以执行以下操作:
import tensorflow as tf
input_tensor = tf.placeholder(dtype=tf.float32, shape=[None, 16, 16, 3])
# static shape
print(input_tensor.shape)
conv_filter = tf.get_variable(
'conv_filter', shape=[2, 2, 3, 6], dtype=tf.float32)
conv1 = tf.nn.conv2d(
input_tensor, conv_filter, strides=[1, 2, 2, 1], padding='SAME')
# static shape
print(conv1.shape)
deconv_filter = tf.get_variable(
'deconv_filter', shape=[2, 2, 6, 3], dtype=tf.float32)
deconv = tf.nn.conv2d_transpose(
input_tensor,
filter=deconv_filter,
# use tf.shape to get the dynamic shape of the tensor
# know at RUNTIME
output_shape=tf.shape(input_tensor),
strides=[1, 2, 2, 1],
padding='SAME')
print(deconv.shape)
程序输出:
(?, 16, 16, 3)
(?, 8, 8, 6)
(?, ?, ?, ?)
正如你所看到的,最后一个形状在编译时是完全未知的,因为我正在conv2d_transpose
使用操作结果设置输出形状tf.shape
,它返回,因此它的值可以在运行时改变。