是否在 Keras 中实现了 Weldon 池化 [1]?
我可以看到它已由作者 [2] 在 pytorch 中实现,但找不到 keras 等效项。
[1] T. Durand、N. Thome 和 M. Cord。Weldon:深度卷积神经网络的弱监督学习。在 CVPR,2016 年。 [2] https://github.com/durandtibo/weldon.resnet.pytorch/tree/master/weldon
是否在 Keras 中实现了 Weldon 池化 [1]?
我可以看到它已由作者 [2] 在 pytorch 中实现,但找不到 keras 等效项。
[1] T. Durand、N. Thome 和 M. Cord。Weldon:深度卷积神经网络的弱监督学习。在 CVPR,2016 年。 [2] https://github.com/durandtibo/weldon.resnet.pytorch/tree/master/weldon
这是一个基于 lua 版本的版本(有一个 pytorch impl,但我认为取 max+min 的平均值存在错误)。我假设 lua 版本的最高最大值和最小值的平均值仍然正确。我还没有测试过整个自定义层方面,但已经足够接近了,欢迎评论。
托尼
class WeldonPooling(Layer):
"""Class to implement Weldon selective spacial pooling with negative evidence
"""
#@interfaces.legacy_global_pooling_support
def __init__(self, kmax, kmin=-1, data_format=None, **kwargs):
super(WeldonPooling, self).__init__(**kwargs)
self.data_format = conv_utils.normalize_data_format(data_format)
self.input_spec = InputSpec(ndim=4)
self.kmax=kmax
self.kmin=kmin
def compute_output_shape(self, input_shape):
if self.data_format == 'channels_last':
return (input_shape[0], input_shape[3])
else:
return (input_shape[0], input_shape[1])
def get_config(self):
config = {'data_format': self.data_format}
base_config = super(_GlobalPooling2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
if self.data_format == "channels_last":
inputs = tf.transpose(inputs, [0, 3, 1, 2])
kmax=self.kmax
kmin=self.kmin
shape=tf.shape(inputs)
batch_size = shape[0]
num_channels = shape[1]
h = shape[2]
w = shape[3]
n = h * w
view = tf.reshape(inputs, [batch_size, num_channels, n])
sorted, indices = tf.nn.top_k(view, n, sorted=True)
#indices_max = tf.slice(indices,[0,0,0],[batch_size, num_channels, kmax])
output = tf.div(tf.reduce_sum(tf.slice(sorted,[0,0,0],[batch_size, num_channels, kmax]),2),kmax)
if kmin > 0:
#indices_min = tf.slice(indices,[0,0, n-kmin],[batch_size, num_channels, kmin])
output=tf.add(output,tf.div(tf.reduce_sum(tf.slice(sorted,[0,0,n-kmin],[batch_size, num_channels, kmin]),2),kmin))
return tf.reshape(output,[batch_size, num_channels])