我正在尝试使用 tensorflow 的tf.map_fn来映射一个参差不齐的张量,但我遇到了一个我无法修复的错误。这是一些演示错误的最小代码:
import tensorflow as tf
X = tf.ragged.constant([[0,1,2], [0,1]])
def outer_product(x):
return x[...,None]*x[None,...]
tf.map_fn(outer_product, X)
我想要的输出是:
tf.ragged.constant([
[[0, 0, 0],
[0, 1, 2],
[0, 2, 4]],
[[0, 0],
[0, 1]]
])
我得到的错误是:
“InvalidArgumentError:所有 flat_values 必须具有兼容的形状。索引 0 处的形状:[3]。索引 1 处的形状:[2]。如果您使用的是 tf.map_fn,那么您可能需要使用适当的 ragged_rank 指定显式 fn_output_signature,并且/ 或将输出张量转换为 RaggedTensors。[Op:RaggedTensorFromVariant]"
我意识到我需要指定 fn_output_signature 但尽管进行了实验,但我无法弄清楚它应该是什么。
编辑:我稍微清理了 AloneTogether 的优秀答案,并创建了一个映射参差不齐的张量的函数。他的回答使用tf.ragged.stack
函数将张量转换为不规则的张量,这tf.map_fn
出于某种原因需要
def ragged_map_fn(func, t):
def new_func(t):
return tf.ragged.stack(func(t),0)
signature = tf.type_spec_from_value(new_func(t[0]))
ans = tf.map_fn(new_func, t, fn_output_signature=signature)
ans = tf.squeeze(ans, 1)
return ans