2

我有一个函数在compute(x)哪里。现在,我想用它来把它转换成一个需要一批数组的函数,然后加快它的速度。是这样的:xjnp.ndarrayvmapx[i]jitcompute(x)

def compute(x):
    # ... some code
    y = very_expensive_function(x)
    return y

但是,每个数组x[i]都有不同的长度。我可以通过用尾随零填充数组来轻松解决这个问题,这样它们都具有相同的长度N并且vmap(compute)可以应用于具有 shape 的批次(batch_size, N)

但是,这样做会导致very_expensive_function()在每个数组的尾随零上也被调用x[i]。有没有办法修改compute()这样的,very_expensive_function()只在切片上调用x,而不干扰vmapand jit

4

1 回答 1

4

使用 JAX,当您想要 jit 函数以加快速度时,给定的批处理参数x必须是定义良好的 ndarray(即 x[i] 必须具有相同的形状)。无论您是否使用vmap.

现在,通常的处理方法是填充这些数组。这意味着您在参数中添加掩码,以便填充值不会影响您的结果。例如,如果我想计算shapesoftmax的填充值,我需要“禁用”填充值的效果。这是一个例子:x(bath_size, max_length)

import jax.numpy as jnp
import jax

PAD = 0
MINUS_INFINITY = -1e6

x = jnp.array([ 
       [1, 2, 3, 4],
       [1, 2, PAD, PAD],
       [1, 2, 3, PAD]
    ])

mask = jnp.array([
           [1, 1, 1, 1],
           [1, 1, 0, 0],
           [1, 1, 1, 0]
       ])
       
masked_sofmax = jax.nn.softmax(x + (1-mask)*MINUS_INFINITY)    

它不像 padding 那样微不足道x。您需要在每一步实际更改计算以禁用填充的效果。在 softmax 的情况下,您可以通过将填充值设置为接近负无穷大来做到这一点。

最后,您无法真正提前知道使用或不使用 padding + masking 的速度性能是否会更好。根据我的经验,它通常会导致 CPU 的良好改进,以及 GPU 的非常大的改进。特别是,批次大小的选择对性能有很大的影响,因为在batch_size统计上,更高的 会导致更高的max_length,因此会导致对填充值执行更多的“无用”计算。

于 2021-07-26T15:48:07.933 回答