JAX的文档说,
并非所有 JAX 代码都可以 JIT 编译,因为它要求数组形状是静态的并且在编译时已知。
现在我有点惊讶,因为 tensorflow 具有类似tf.boolean_mask
JAX 在编译时似乎无法执行的操作。
- 为什么 TensorFlow 会出现这样的回归?我假设底层 XLA 表示在两个框架之间共享,但我可能弄错了。我不记得 Tensorflow 曾经在动态形状方面遇到过问题,而且诸如此类的功能
tf.boolean_mask
已经存在了很久。 - 我们可以期待这种差距在未来缩小吗?如果不是,为什么在 JAX 的 jit 中无法实现 Tensorflow(以及其他)所支持的功能?
编辑
梯度通过tf.boolean_mask
(显然不在掩码值上,它们是离散的);此处使用值未知的 TF1 样式图为例,因此 TF 不能依赖它们:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
x1 = tf.placeholder(tf.float32, (3,))
x2 = tf.placeholder(tf.float32, (3,))
y = tf.boolean_mask(x1, x2 > 0)
print(y.shape) # prints "(?,)"
dydx1, dydx2 = tf.gradients(y, [x1, x2])
assert dydx1 is not None and dydx2 is None