2

tf.nn.conv2d_transpose 的文档说:

tf.nn.conv2d_transpose(
    value,
    filter,
    output_shape,
    strides,
    padding='SAME',
    data_format='NHWC',
    name=None
)

output_shape 参数需要一个一维张量,指定此操作输出的张量的形状。在这里,由于我的 conv-net 部分完全建立在动态 batch_length 占位符上,我似乎无法batch_size为这个操作的 output_shape 的静态要求提供解决方法。

网上有很多关于这个的讨论,但是,我找不到任何可靠的解决方案。它们中的大多数都是global_batch_size定义了变量的hacky。我想知道这个问题的最佳解决方案。这个训练有素的模型将作为部署服务提供。

4

4 回答 4

3

您可以使用参考张量的动态形状,而不是静态形状。

通常,当您使用该conv2d_transpose操作时,您会“上采样”一层以获得网络中另一个张量的某种形状。

例如,如果您想复制input_tensor张量的形状,您可以执行以下操作:

import tensorflow as tf

input_tensor = tf.placeholder(dtype=tf.float32, shape=[None, 16, 16, 3])
# static shape
print(input_tensor.shape)

conv_filter = tf.get_variable(
    'conv_filter', shape=[2, 2, 3, 6], dtype=tf.float32)
conv1 = tf.nn.conv2d(
    input_tensor, conv_filter, strides=[1, 2, 2, 1], padding='SAME')
# static shape
print(conv1.shape)

deconv_filter = tf.get_variable(
    'deconv_filter', shape=[2, 2, 6, 3], dtype=tf.float32)

deconv = tf.nn.conv2d_transpose(
    input_tensor,
    filter=deconv_filter,
    # use tf.shape to get the dynamic shape of the tensor
    # know at RUNTIME
    output_shape=tf.shape(input_tensor),
    strides=[1, 2, 2, 1],
    padding='SAME')
print(deconv.shape)

程序输出:

(?, 16, 16, 3)
(?, 8, 8, 6)
(?, ?, ?, ?)

正如你所看到的,最后一个形状在编译时是完全未知的,因为我正在conv2d_transpose使用操作结果设置输出形状tf.shape,它返回,因此它的值可以在运行时改变。

于 2017-10-23T09:49:37.143 回答
2

您可以使用以下代码根据该层的输入 ( input ) 和该层的输出数量 ( num_outputs ) 计算tf.nn.conv2d_transpose的输出形状参数。当然,您有过滤器大小、填充、步幅和 data_format。

def calculate_output_shape(input, filter_size_h, filter_size_w, 
    stride_h, stride_w, num_outputs, padding='SAME', data_format='NHWC'):

    #calculation of the output_shape:
    if data_format == "NHWC":
        input_channel_size = input.get_shape().as_list()[3]
        input_size_h = input.get_shape().as_list()[1]
        input_size_w = input.get_shape().as_list()[2]
        stride_shape = [1, stride_h, stride_w, 1]
        if padding == 'VALID':
            output_size_h = (input_size_h - 1)*stride_h + filter_size_h
            output_size_w = (input_size_w - 1)*stride_w + filter_size_w
        elif padding == 'SAME':
            output_size_h = (input_size_h - 1)*stride_h + 1
            output_size_w = (input_size_w - 1)*stride_w + 1
        else:
            raise ValueError("unknown padding")

        output_shape = tf.stack([tf.shape(input)[0], 
                            output_size_h, output_size_w, 
                            num_outputs])
    elif data_format == "NCHW":
        input_channel_size = input.get_shape().as_list()[1]
        input_size_h = input.get_shape().as_list()[2]
        input_size_w = input.get_shape().as_list()[3]
        stride_shape = [1, 1, stride_h, stride_w]
        if padding == 'VALID':
            output_size_h = (input_size_h - 1)*stride_h + filter_size_h
            output_size_w = (input_size_w - 1)*stride_w + filter_size_w
        elif padding == 'SAME':
            output_size_h = (input_size_h - 1)*stride_h + 1
            output_size_w = (input_size_w - 1)*stride_w + 1
        else:
            raise ValueError("unknown padding")

        output_shape = tf.stack([tf.shape(input)[0], 
                                output_size_h, output_size_w, num_outputs])
    else:
        raise ValueError("unknown data_format")

    return output_shape
于 2018-08-03T01:29:58.913 回答
1

您可以使用 -1 的值来替换 的确切值batch_size。考虑下面的示例,我将形状为 (16, 16, 3) 的可变批量大小的输入张量转换为 (32, 32, 6)。

import tensorflow as tf

input_tensor = tf.placeholder(dtype = tf.float32, shape = [None, 16, 16, 3])
print (input_tensor.shape)

my_filter = tf.get_variable('filter', shape = [2, 2, 6, 3], dtype = tf.float32)
conv = tf.nn.conv2d_transpose(input_tensor,
                              filter = my_filter,
                              output_shape = [-1, 32, 32, 6],
                              strides = [1, 2, 2, 1],
                              padding = 'SAME')
print (conv.shape)

将输出你:

(?, 16, 16, 3)
(?, 32, 32, 6)
于 2017-10-23T08:59:29.450 回答
0

只需在需要 train_batch_size 时使用 tf.shape(X_batch)[0]

于 2020-06-14T05:17:50.853 回答