我有一个函数在compute(x)哪里。现在,我想用它来把它转换成一个需要一批数组的函数,然后加快它的速度。是这样的:xjnp.ndarrayvmapx[i]jitcompute(x)
def compute(x):
# ... some code
y = very_expensive_function(x)
return y
但是,每个数组x[i]都有不同的长度。我可以通过用尾随零填充数组来轻松解决这个问题,这样它们都具有相同的长度N并且vmap(compute)可以应用于具有 shape 的批次(batch_size, N)。
但是,这样做会导致very_expensive_function()在每个数组的尾随零上也被调用x[i]。有没有办法修改compute()这样的,very_expensive_function()只在切片上调用x,而不干扰vmapand jit?