1

我有一个 2D RaggedTensor,其中包含我想要的完整张量每一行的索引,例如:

[
    [0,4],
    [1,2,3],
    [5]
]

进入

[
    [200, 305, 400, 20, 20, 105],
    [200, 315, 401, 20, 20, 167],
    [200, 7, 402, 20, 20, 105],
]

[
    [200,20],
    [315,401,20],
    [105]
]

我怎样才能以最有效的方式实现这一点(最好只使用tf函数)?我相信像这样gather_nd的东西能够使用 RaggedTensors,但我无法弄清楚它是如何工作的。

4

1 回答 1

1

您可以将tf.gather, 与batch_dims关键字参数一起使用:

>>> tf.gather(tensor,indices,batch_dims=1)
<tf.RaggedTensor [[200, 20], [315, 401, 20], [105]]>
于 2021-04-16T13:44:29.433 回答