我有一个函数在compute(x)
哪里。现在,我想用它来把它转换成一个需要一批数组的函数,然后加快它的速度。是这样的:x
jnp.ndarray
vmap
x[i]
jit
compute(x)
def compute(x):
# ... some code
y = very_expensive_function(x)
return y
但是,每个数组x[i]
都有不同的长度。我可以通过用尾随零填充数组来轻松解决这个问题,这样它们都具有相同的长度N
并且vmap(compute)
可以应用于具有 shape 的批次(batch_size, N)
。
但是,这样做会导致very_expensive_function()
在每个数组的尾随零上也被调用x[i]
。有没有办法修改compute()
这样的,very_expensive_function()
只在切片上调用x
,而不干扰vmap
and jit
?