我有一个张量的形状(16, 4096, 3)。我有另一个形状指数张量(16, 32768, 3)。我正在尝试收集这些值dim=1。这最初是在 pytorch 中使用收集功能完成的,如下所示-
# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)
请注意,输出的大小b与idx. 但是,当我应用gathertensorflow 的功能时,我得到了完全不同的输出。发现输出维度不匹配,如下所示 -
b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)
我也尝试使用tf.gather_nd但徒劳无功。见下文-
b = tf.gather_nd(a, idx)
# b.shape (16, 32768)
为什么我会得到不同形状的张量?我想得到与 pytorch 计算的相同形状的张量。
换句话说,我想知道torch.gather的tensorflow等价物。