2

我自定义了一个层,将batch_size和第一个维度合并,其他维度保持不变,但是compute_output_shape好像没有效果,导致后续层无法获取准确的形状信息,导致报错。如何使 compute_output_shape 工作?

import keras
from keras import backend as K

class BatchMergeReshape(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(BatchMergeReshape, self).__init__(**kwargs)

    def build(self, input_shape):
        super(BatchMergeReshape, self).build(input_shape)  

    def call(self, x):
        input_shape = K.shape(x)
        batch_size, seq_len = input_shape[0], input_shape[1]
        r = K.reshape(x, (batch_size*seq_len,)+input_shape[2:])
        print("call_shape:",r.shape)
        return r

    def compute_output_shape(self, input_shape):
        if input_shape[0] is None:
            r = (None,)+input_shape[2:]
            print("compute_output_shape:",r)
            return r
        else:
            r = (input_shape[0]*input_shape[1],)+input_shape[2:]
            return r

a = keras.layers.Input(shape=(3,4,5))
b = BatchMergeReshape()(a)
print(b.shape)

# call_shape: (?, ?)
# compute_output_shape: (None, 4, 5)
# (?, ?)

我需要得到 (None,4,5) 但得到 (None,None),为什么 compute_output_shape 不起作用。我的 keras 版本是 2.2.4

4

1 回答 1

1

问题可能是K.shape返回张量而不是元组。你不能这样做(batch_size*seq_len,) + input_shape[2:]。这是混合了很多东西,张量和元组,结果肯定是错误的。

现在的好处是,如果您知道其他维度而不知道批量大小,您只需要这一层:

Lambda(lambda x: K.reshape(x, (-1,) + other_dimensions_tuple))

如果你不这样做:

input_shape = K.shape(x)
new_batch_size = input_shape[0:1] * input_shape[1:2] #needs to keep a shape of an array   
                 #new_batch_size.shape = (1,)
new_shape = K.concatenate([new_batch_size, input_shape[2:]]) #this is a tensor   
                                                             #result of concatenating 2 tensors   

r = K.reshape(x, new_shape)

请注意,这在 Tensorflow 中有效,但在 Theano 中可能无效。

另请注意,Keras 将要求模型输出的批量大小等于模型输入的批量大小。这意味着您需要在模型结束之前恢复原始批量大小。

于 2019-09-24T03:58:52.340 回答