我找到了一个使用 tf.gather_nd 的解决方案,它可以工作,虽然看起来不那么优雅。
我使用了此处发布的 unravel_argmax 函数。
def unravel_argmax(argmax, shape):
    """Split flat argmax indices into two index components.

    The first component is ``argmax // (shape[2] * shape[3])`` and the second
    is ``argmax % (shape[2] * shape[3]) // shape[3]``. Both are stacked along
    a new leading axis, so the result has one more dimension than ``argmax``.

    NOTE(review): presumably ``shape`` is (batch, height, width, channels) and
    ``argmax`` is flattened per image as ``(y * width + x) * channels + c``,
    which would make the two components the row and column of each maximum.
    Confirm against the ``tf.nn.max_pool_with_argmax`` settings in use.

    Args:
        argmax: Integer tensor of flattened indices.
        shape: Indexable shape of the pooled tensor; entries 2 and 3 are read.

    Returns:
        Tensor holding the two components stacked on axis 0.
    """
    # Number of flat positions covered by one step of the first component.
    row_stride = shape[2] * shape[3]
    first = argmax // row_stride
    second = argmax % row_stride // shape[3]
    return tf.stack([first, second])
def max_pool(input, ksize, strides, padding):
    """Max-pool ``input`` by gathering the maxima with ``tf.gather_nd``.

    The argmax returned by ``tf.nn.max_pool_with_argmax`` is converted into
    full 4-component indices, which are then used to read the selected values
    straight out of ``input``.

    NOTE(review): assumes ``input`` is a rank-4 tensor laid out as
    (batch, height, width, channels) with statically known output spatial
    dimensions and channel count -- confirm against callers.

    Args:
        input: Tensor to pool.
        ksize: Pooling window size, forwarded to ``max_pool_with_argmax``.
        strides: Pooling strides, forwarded to ``max_pool_with_argmax``.
        padding: Padding mode, forwarded to ``max_pool_with_argmax``.

    Returns:
        Tensor of the values of ``input`` at the argmax positions.
    """
    # Only the argmax is needed; the pooled values are re-read via gather_nd.
    _, arg_max = tf.nn.max_pool_with_argmax(
        input=input, ksize=ksize, strides=strides, padding=padding
    )
    shape = input.get_shape()
    arg_max = tf.cast(arg_max, tf.int32)

    # unravel_argmax stacks its two components on axis 0; move that axis last
    # so every output position carries its own pair of indices.
    unraveled = unravel_argmax(arg_max, shape)
    indices = tf.transpose(unraveled, (1, 2, 3, 4, 0))

    channels = shape[-1]
    # Batch size is taken dynamically from the input tensor itself.
    bs = tf.shape(input)[0]
    # Static sizes of the two middle axes of the pooled output.
    spatial = tuple(indices.get_shape()[1:-2].as_list())

    # Channel index for every output position.
    channel_range = tf.range(channels, dtype=arg_max.dtype)
    channel_idx = tf.tile(
        channel_range[None, None, None, :, None],
        multiples=(bs,) + spatial + (1, 1),
    )
    spatial_channel_idx = tf.concat((indices, channel_idx), axis=-1)

    # Batch index for every output position.
    batch_range = tf.range(tf.cast(bs, dtype=arg_max.dtype))
    batch_idx = tf.tile(
        batch_range[:, None, None, None, None],
        (1,) + spatial + (channels, 1),
    )

    # Full index per position: batch, the two unravelled components, channel.
    full_idx = tf.concat((batch_idx, spatial_channel_idx), -1)
    return tf.gather_nd(input, full_idx)
如果有人有更优雅的解决方案,我仍然很想知道。
Mat