0

有没有办法根据 XLA 编译函数中的随机数生成器动态切片张量?例如:

@tf.function(experimental_compile=True)
def random_slice(input, max_slice_size):
    offset = tf.squeeze(tf.random.uniform([1], minval=0, maxval=input.shape[0]-max_slice_size, dtype=tf.int32))
    sz = tf.squeeze(tf.random.uniform([1], minval=1, maxval=max_slice_size, dtype=tf.int32))

    indices = tf.range(offset, offset+sz)  # Non-XLA-able due to non-static bounds

    return tf.gather(input, indices)

x = tf.ones([50, 50])
y = random_slice(x, 4)

此代码无法编译,因为 XLA 要求tf.range在编译时知道参数。有推荐的解决方法吗?

4

1 回答 1

0

这里的根本问题是 XLA 需要静态地知道Tensor程序中所有 s 的形状。在这种情况下,它会抱怨,tf.range因为在给定随机输入的情况下,它的输出是不可知的。相反,您可能能够摆脱生成屏蔽版本(将不需要的元素归零,使用类似 tensor_scatter_nd_update 之类的东西)并在下游使用该屏蔽版本(很难确切地说出如何,没有看到更多关于如何y的上下文使用)。

于 2020-05-29T21:31:16.107 回答