在尝试对大型数组进行 SVD 压缩时,我在 Jax 中遇到了一些我不理解的行为。这是示例代码:
@jit
def jax_compress(L):
U, S, _ = jsc.linalg.svd(L,
full_matrices = False,
lapack_driver = 'gesvd',
check_finite=False,
overwrite_a=True)
maxS=jnp.max(S)
chi = jnp.sum(S/maxS>1E-1)
return chi, jnp.asarray(U)
在考虑这段代码时,Jax/jit 比 SciPy 提供了巨大的性能提升,但最终我想减少 U 的维数,我通过将它包装在函数中来做到这一点:
def jax_process(A):
chi, U = jax_compress(A)
return U[:,0:chi]
这一步在计算时间方面的成本令人难以置信,比 SciPy 的等价物更昂贵,从这个比较中可以看出:
sc_compress
并且sc_process
是上面 jax 代码的 SciPy 等价物。如您所见,在 SciPy 中对数组进行切片几乎不需要任何成本,但在应用于 hit 函数的输出时却非常昂贵。有人对这种行为有一些了解吗?