我认为您可以通过 tf.nn.pool 实现您想要的效果:
import tensorflow as tf

# Max-pool consecutive pairs of rows along the middle axis of a 3-D tensor.
# Written against the graph/session API (tf.Session).
with tf.Graph().as_default(), tf.Session() as sess:
    # Input of shape (2, 6, 3): 2 items, 6 rows each, 3 values per row.
    data = tf.constant([
        [
            [ 1, 12, 13],
            [ 2, 11, 14],
            [ 3, 10, 15],
            [ 4,  9, 16],
            [ 5,  8, 17],
            [ 6,  7, 18],
        ],
        [
            [19, 30, 31],
            [20, 29, 32],
            [21, 28, 33],
            [22, 27, 34],
            [23, 26, 35],
            [24, 25, 36],
        ]], dtype=tf.int32)
    # Segment ids grouping the 6 rows into consecutive pairs. Kept for
    # reference only: the pooling call below does not use them.
    segments = tf.constant([0, 0, 1, 1, 2, 2], dtype=tf.int32)
    # Window of 2 with stride 2 -> non-overlapping pairs of rows, reduced
    # with an element-wise maximum.
    pool = tf.nn.pool(data, [2], 'MAX', 'VALID', strides=[2])
    print(sess.run(pool))
输出:
[[[ 2 12 14]
  [ 4 10 16]
  [ 6  8 18]]

 [[20 30 32]
  [22 28 34]
  [24 26 36]]]
如果你真的想要使用 tf.unsorted_segment_max,你可以按照你自己回答中的建议去做。下面是一个等效的写法,它避免了转置,并包含了最后的重塑(reshape)步骤:
import tensorflow as tf

# Segment-wise maximum along the middle axis of a 3-D tensor, computed with
# a single tf.unsorted_segment_max call over flattened segment ids.
with tf.Graph().as_default(), tf.Session() as sess:
    # Placeholders: fill in the 3-D input tensor and the 1-D segment ids
    # (one id per position along the middle axis).
    data = ...
    segments = ...
    shape = tf.shape(data)
    # n: size of the first axis, k: size of the last axis.
    n, k = shape[0], shape[2]
    # Number of distinct segments, taking ids to start at 0.
    m = tf.reduce_max(segments) + 1
    # Build, for every element data[i, j, c], the flat output index
    #   i * (m * k) + segments[j] * k + c
    # i.e. its position in a row-major (n, m, k) result.
    grid = tf.meshgrid(tf.range(n) * m * k,
                       segments * k,
                       tf.range(k), indexing='ij')
    segment_nd = tf.add_n(grid)
    # Every element has its own segment id, so the result is a flat vector
    # of n * m * k maxima.
    segmented = tf.unsorted_segment_max(data, segment_nd, n * m * k)
    # Restore the (n, m, k) layout encoded by the flat index.
    result = tf.reshape(segmented, [n, m, k])
    print(sess.run(result))
    # Same output
就反向传播而言,这两种方法都应该在神经网络中正常工作。
编辑:就性能而言,池化似乎比分段(unsorted_segment_max)方法更具可扩展性(正如人们所期望的那样):
import tensorflow as tf
import numpy as np
def method_pool(data, window):
    """Max-pool `data` in non-overlapping windows of size `window`.

    Calls tf.nn.pool with a single window dimension, 'MAX' pooling,
    'VALID' padding and a stride equal to the window size.

    Args:
        data: Input tensor; assumed to be 3-D with the pooled axis in the
            middle, as in the benchmark that calls this - TODO confirm for
            other callers.
        window: Window size, also used as the stride.

    Returns:
        The pooled tensor produced by tf.nn.pool.
    """
    return tf.nn.pool(data, [window], 'MAX', 'VALID', strides=[window])
def method_segment(data, window):
    """Windowed maximum along axis 1 via tf.unsorted_segment_max.

    Positions along axis 1 are grouped into consecutive segments of
    `window` elements (position j belongs to segment j // window), and the
    element-wise maximum of each segment is taken.

    Args:
        data: 3-D tensor; its dynamic shape is read as (n, s, k).
        window: Number of consecutive positions along axis 1 per segment.

    Returns:
        Tensor of shape (n, m, k), where m is the number of segments
        (largest segment id + 1).
    """
    shape = tf.shape(data)
    n, s, k = shape[0], shape[1], shape[2]
    # Segment id for each position along axis 1: 0,0,...,1,1,...
    segments = tf.range(s) // window
    m = tf.reduce_max(segments) + 1
    # Flat output index for data[i, j, c]:
    #   i * (m * k) + segments[j] * k + c
    # which is its row-major position in the (n, m, k) result.
    grid = tf.meshgrid(tf.range(n) * m * k,
                       segments * k,
                       tf.range(k), indexing='ij')
    segment_nd = tf.add_n(grid)
    segmented = tf.unsorted_segment_max(data, segment_nd, n * m * k)
    return tf.reshape(segmented, [n, m, k])
np.random.seed(100)
rand_data = np.random.rand(300, 500, 100)
window = 10
with tf.Graph().as_default(), tf.Session() as sess:
data = tf.constant(rand_data, dtype=tf.float32)
res_pool = method_pool(data, n)
res_segment = method_segment(data, n)
print(np.allclose(*sess.run([res_pool, res_segment])))
# True
%timeit sess.run(res_pool)
# 2.56 ms ± 80.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit sess.run(res_segment)
# 514 ms ± 6.29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)