1

我有一个从 1 到 10k 的 row_lens 的 raggedTensor。我想以可扩展的方式从中随机选择元素,每行的数量有上限。就像在这个例子中一样:

vect = [[1,2,3],[4,5][6],[7,8,9,10,11,12,13]]
limit = 3
sample(vect, limit)

-> 输出:[[1,2,3],[4,5],[6],[7,9,11]]

我的想法是在 len_row < limit 的情况下选择 * ,在另一种情况下随机选择。我想知道这是否可以通过一些 tensorflow 操作以低于 batch_size 的复杂度来完成?

4

1 回答 1

1

您可以尝试tf.map_fn在图形模式下使用:

import tensorflow as tf

vect = tf.ragged.constant([[1,2,3],[4,5],[6],[7,8,9,10,11,12,13]])

@tf.function
def sample(x, samples=3):
  length = tf.shape(x)[0]
  x = tf.cond(tf.less_equal(length, samples), lambda: x, lambda: tf.gather(x, tf.random.shuffle(tf.range(length))[:samples]))
  return x

c = tf.map_fn(sample, vect)
<tf.RaggedTensor [[1, 2, 3], [4, 5], [6], [12, 7, 9]]>

请注意,tf.vectorized_map这可能会更快,但当前存在关于此函数和参差不齐的张量的错误。使用tf.while_loop也是一种选择。

于 2022-02-11T12:23:33.307 回答