1

我希望能够在大小为 [N, s, K] 的数据张量上调用 tensorflow 的 tf.math.unsorted_segment_max。N 是通道数,K 是过滤器/特征图的数量。s 是单通道数据样本的大小。我有 s 大小的 segment_ids。例如,假设我的样本大小是 s=6,并且我想对两个元素进行最大处理(就像进行通常的最大池化一样,所以在第二个,整个数据张量的 s 维)。然后我的 segment_ids 等于 [0,0,1,1,2,2]。

我试着跑步

tf.math.unsorted_segment_max(data, segment_ids, num_segments)

segment_ids 具有扩展的 0 和 2 维,但由于段 id 然后重复,结果当然是大小 [3] 而不是我想要的 [N,3,K]。

所以我的问题是,如何构造一个合适的 segment_ids 张量,来实现我想要的?即根据原始 s 大小的 segment_ids 张量完成最大段,但在每个维度中分别?

基本上,回到这个例子,给定一维段 id 列表 seg_id=[0,0,1,1,2,2],我想构造一个类似于 segment_ids 张量的东西:

segment_ids[i,:,j] = seg_id + num_segments*(i*K + j) 

因此,当使用此张量作为段 id 调用 tf.math.(unsorted_)segment_max 时,我将得到大小为 [N, 3, K] 的结果,其效果与对每个数据运行 segment_max 的效果相同[ x,:,y] 分别并适当地堆叠结果。

这样做的任何方式都可以,只要它适用于 tensorflow。我猜想 tf.tile、tf.reshape 或 tf.concat 的组合应该可以解决问题,但我不知道如何,以什么顺序。另外,有没有更直接的方法呢?不需要在每个“池化”步骤中调整 segment_ids?

4

2 回答 2

2

我认为您可以通过以下方式实现您想要的tf.nn.pool

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    data = tf.constant([
        [
            [ 1, 12, 13],
            [ 2, 11, 14],
            [ 3, 10, 15],
            [ 4,  9, 16],
            [ 5,  8, 17],
            [ 6,  7, 18],
        ],
        [
            [19, 30, 31],
            [20, 29, 32],
            [21, 28, 33],
            [22, 27, 34],
            [23, 26, 35],
            [24, 25, 36],
        ]], dtype=tf.int32)
    segments = tf.constant([0, 0, 1, 1, 2, 2], dtype=tf.int32)
    pool = tf.nn.pool(data, [2], 'MAX', 'VALID', strides=[2])
    print(sess.run(pool))

输出:

[[[ 2 12 14]
  [ 4 10 16]
  [ 6  8 18]]

 [[20 30 32]
  [22 28 34]
  [24 26 36]]]

如果你真的想要我们tf.unsorted_segment_max,你可以按照你自己回答中的建议去做。这是一个避免转置并包括最终重塑的等效公式:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    data = ...
    segments = ...
    shape = tf.shape(data)
    n, k = shape[0], shape[2]
    m = tf.reduce_max(segments) + 1
    grid = tf.meshgrid(tf.range(n) * m * k,
                       segments * k,
                       tf.range(k), indexing='ij')
    segment_nd = tf.add_n(grid)
    segmented = tf.unsorted_segment_max(data, segment_nd, n * m * k)
    result = tf.reshape(segmented, [n, m, k])
    print(sess.run(result))
    # Same output

就反向传播而言,这两种方法都应该在神经网络中正常工作。

编辑:就性能而言,池化似乎比分段总和更具可扩展性(正如人们所期望的那样):

import tensorflow as tf
import numpy as np

def method_pool(data, window):
    return tf.nn.pool(data, [window], 'MAX', 'VALID', strides=[window])

def method_segment(data, window):
    shape = tf.shape(data)
    n, s, k = shape[0], shape[1], shape[2]
    segments = tf.range(s) // window
    m = tf.reduce_max(segments) + 1
    grid = tf.meshgrid(tf.range(n) * m * k,
                       segments * k,
                       tf.range(k), indexing='ij')
    segment_nd = tf.add_n(grid)
    segmented = tf.unsorted_segment_max(data, segment_nd, n * m * k)
    return tf.reshape(segmented, [n, m, k])

np.random.seed(100)
rand_data = np.random.rand(300, 500, 100)
window = 10
with tf.Graph().as_default(), tf.Session() as sess:
    data = tf.constant(rand_data, dtype=tf.float32)
    res_pool = method_pool(data, n)
    res_segment = method_segment(data, n)
    print(np.allclose(*sess.run([res_pool, res_segment])))
    # True
    %timeit sess.run(res_pool)
    # 2.56 ms ± 80.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    %timeit sess.run(res_segment)
    # 514 ms ± 6.29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
于 2019-03-08T13:06:22.110 回答
0

我还没有想出任何更优雅的解决方案,但至少我想出了如何结合 tile、reshape 和 transpose 来做到这一点。我首先(使用提到的三个操作,参见下面的代码)构造一个与数据大小相同的张量,在张量中重复(但移位)原始 seg_id 向量的条目:

m = tf.reduce_max(seg_id) + 1
a = tf.constant([i*m for i in range(N*K) for j in range(s)])
b = tf.tile(seg_id, N*K)
#now reshape it:
segment_ids = tf.transpose(tf.reshape(a+b, shape=[N,K,s]), perm=[0,2,1])

这样就可以直接调用 segment_max 函数:

result = tf.unsorted_segment_max(data=data, segment_ids=segment_ids, num_segments=m*N*K)

它也做了我想要的,除了结果是扁平的,如果需要,需要再次重塑。等效地,您可以将原始数据张量重塑为 1d 张量,并在其上使用 a+b 作为 segment_ids 校准 segment_max。如果需要,再次重塑最终结果。

这感觉就像一个很长的路要结果......有没有更好的方法?我也不知道所描述的方式是否适合在 NN 内部使用,在反向传播期间......导数或计算图是否存在问题?有没有人对如何解决这个问题有更好的想法?

于 2019-03-08T12:04:59.313 回答