0

这个 Tensorflow 文档给出了tf.map_fn在不规则张量上使用的示例,该示例适用于 Tensorflow 2.4.1 及更高版本:

digits = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
print(tf.map_fn(tf.math.square, digits))

但是,以下示例在 Tensorflow 2.4.1 或 Tensorflow 2.5 中运行时会导致错误“'RaggedTensor' 类型的对象没有 len”:

import tensorflow as tf

X=tf.ragged.constant([[1.,2.],[3.,4.,5.]], dtype=tf.float32)

@tf.function
def powerX(i):
    global X
    return X**i

Y = tf.map_fn(powerX, tf.range(3, dtype=tf.float32))

有没有办法使这项工作?我不明白抛出的错误。一般来说,我试图通过映射一个用户定义的函数来获得完全的并行性,该函数在一个参差不齐的张量上只有 Tensorflow 操作,结果是参差不齐的张量。

4

1 回答 1

1

tf.map_fn需要输出签名。我不确定为什么它不能从输入中推断出这一点,但这是 tensorflow 人的问题。以下代码将为您工作。

import tensorflow as tf

X=tf.ragged.constant([[1.,2.],[3.,4.,5.]], dtype=tf.float32)

@tf.function
def powerX(i):
    global X
    return X**i
signature = tf.type_spec_from_value(powerX(X))
Y = tf.map_fn(powerX, tf.range(3, dtype=tf.float32),fn_output_signature=signature)
于 2021-12-21T02:42:35.083 回答