2

JAX的文档说,

并非所有 JAX 代码都可以 JIT 编译,因为它要求数组形状是静态的并且在编译时已知。

现在我有点惊讶,因为 tensorflow 具有类似tf.boolean_maskJAX 在编译时似乎无法执行的操作。

  1. 为什么 TensorFlow 会出现这样的回归?我假设底层 XLA 表示在两个框架之间共享,但我可能弄错了。我不记得 Tensorflow 曾经在动态形状方面遇到过问题,而且诸如此类的功能tf.boolean_mask已经存在了很久。
  2. 我们可以期待这种差距在未来缩小吗?如果不是,为什么在 JAX 的 jit 中无法实现 Tensorflow(以及其他)所支持的功能?

编辑

梯度通过tf.boolean_mask(显然不在掩码值上,它们是离散的);此处使用值未知的 TF1 样式图为例,因此 TF 不能依赖它们:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

x1 = tf.placeholder(tf.float32, (3,))
x2 = tf.placeholder(tf.float32, (3,))
y = tf.boolean_mask(x1, x2 > 0)
print(y.shape)  # prints "(?,)"
dydx1, dydx2 = tf.gradients(y, [x1, x2])
assert dydx1 is not None and dydx2 is None
4

2 回答 2

1

目前,您不能如此处所述

这不是 JAX jit vs TensorFlow 的限制,而是 XLA 的限制,或者更确切地说是两者如何编译的限制。

JAX 仅使用 XLA 来编译该函数。XLA需要知道静态形状。这是XLA中固有的设计选择。

TensorFlow 使用function: 这创建了一个图形,该图形可以具有静态未知的形状。这不如使用 XLA 高效,但仍然可以。但是,tf.function提供了一个选项jit_compile,它将使用 XLA 编译函数内部的图形。虽然这通常会提供不错的加速(免费),但它有一些限制:形状需要静态已知(惊喜、惊喜……)

这总体上不是太令人惊讶的行为:计算机中的计算通常更快(给定一个体面的优化器),以前称为更多参数(内存布局,...)可以优化调度。知道的越少,代码就越慢(在这端是普通的 Python)。

于 2021-11-30T18:36:06.430 回答
0

我认为 JAX 并不比 TensorFlow 更无能为力。没有什么禁止你在 JAX 中这样做:

new_array = my_array[mask]

但是,mask应该是索引(整数)而不是布尔值。这样,JAX 就知道new_array( 与 相同mask) 的形状。从这个意义上说,我很确定它tf.boolean_mask是不可微的,即如果你尝试在某个点计算它的梯度,它会引发错误。

更一般地说,如果您需要屏蔽数组,无论您使用什么库,有两种方法:

  1. 如果您事先知道需要选择哪些索引并且您需要提供这些索引以便库可以在编译之前计算形状;
  2. 如果您无法定义这些索引,无论出于何种原因,那么您需要设计您的代码以避免防止填充影响您的结果。

每种情况的示例

  1. 假设您正在 JAX 中编写一个简单的嵌入层。是对应于几个句子的input一批标记索引。为了获得与这些索引相对应的词嵌入,我将简单地编写word_embeddings = embeddings[input]. 由于我事先不知道句子的长度,所以我需要预先将所有标记序列填充到相同的长度,这样input就是 shape (number_of_sentences, sentence_max_length)。现在,每当这个形状发生变化时,JAX 都会编译屏蔽操作。为了尽量减少编译次数,您可以提供相同数量的句子(也称为批量大小),您可以将其设置sentence_max_length为整个语料库中的最大句子长度。这样,训练期间将只有一个编译。当然,你需要在里面预留一排word_embeddings对应于焊盘索引。但是,掩蔽仍然有效。

  2. 稍后在模型中,假设您想将每个句子的每个单词表示为句子中所有其他单词的加权平均值(如自我注意机制)。权重是为整个批次并行计算的,并存储在A维度矩阵中(number_of_sentences, sentence_max_length, sentence_max_length)。加权平均值使用公式计算A @ word_embeddings。现在,您需要确保填充标记不会影响前面的公式。为此,您可以将 A 中与焊盘索引相对应的条目清零,以消除它们对平均的影响。如果填充令牌索引为 0,您将执行以下操作:

    mask = jnp.array(input > 0, dtype=jnp.float32)
    A = A * mask[:, jnp.newaxis, :]
    weighted_mean = A @ word_embeddings 

所以这里我们使用了一个布尔掩码,但是掩码在某种程度上是可微的,因为我们将掩码与另一个矩阵相乘,而不是将其用作索引。请注意,我们应该以相同的方式删除weighted_mean也对应于填充标记的行。

于 2021-05-05T15:01:09.090 回答