2

在尝试对大型数组进行 SVD 压缩时,我在 Jax 中遇到了一些我不理解的行为。这是示例代码:

@jit 
def jax_compress(L):
    U, S, _ = jsc.linalg.svd(L, 
    full_matrices = False,
    lapack_driver = 'gesvd',
    check_finite=False,
    overwrite_a=True)

    maxS=jnp.max(S)
    chi = jnp.sum(S/maxS>1E-1)

    return chi, jnp.asarray(U)

在考虑这段代码时,Jax/jit 比 SciPy 提供了巨大的性能提升,但最终我想减少 U 的维数,我通过将它包装在函数中来做到这一点:

def jax_process(A):

    chi, U = jax_compress(A)
    
    return U[:,0:chi]

这一步在计算时间方面的成本令人难以置信,比 SciPy 的等价物更昂贵,从这个比较中可以看出:

jax 和 scipy 的基准测试

sc_compress并且sc_process是上面 jax 代码的 SciPy 等价物。如您所见,在 SciPy 中对数组进行切片几乎不需要任何成本,但在应用于 hit 函数的输出时却非常昂贵。有人对这种行为有一些了解吗?

4

2 回答 2

1

我对 JAX 和 PyTorch 之间的切片速度进行了类似的比较。dynamic_slice比普通切片快得多,但仍然比火炬中的同等切片要慢得多。由于我是 JAX 新手,我不确定原因是什么,但这可能与复制与引用有关,因为 JAX 数组是不可变的。

JAX(没有@jit)

key = random.PRNGKey(0)
j = random.normal(key, (32, 2, 1024, 1024, 3))
%timeit j[..., 100:600, 100:600, :].block_until_ready()
%timeit dynamic_slice(j, [0, 0, 100, 100, 0], [32, 2, 500, 500, 3]).block_until_ready()
2.78 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
993 µs ± 12.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

PyTorch

t = torch.randn((32, 2, 1024, 1024, 3)).cuda()

%%timeit 
t[..., 100:600, 100:600, :]
torch.cuda.synchronize()
7.63 µs ± 22.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
于 2021-04-21T11:07:15.743 回答
0

我不是 Jax 专家,我不确定它在幕后是如何工作的,但我运行了那个片段并看了看。

我很确定内部的 Jax 函数jax_compress(或来自 jit 装饰器的效果)是延迟评估的,因此只有当您在计算结束时“查看内部”输出矩阵并实际要求时,它们才会执行完整计算具体数字(很像 python 生成器做的事情,以及像 Haskell 这样的函数式语言)。

我认为你在最后做的数组切片基本上是对你的矩阵“提出一个具体问题”的一种形式。

您可以通过jax_compress自行定时函数并在访问数组元素后检查这一点:

ti = time.time()
X, U = jax_compress(A)
# very fast
print(f"Compession takes {time.time() - ti} seconds when not peeking")

ti = time.time()
X, U = jax_compress(A)
# much slower
print(U[0][0])
print(f"Compession takes {time.time() - ti} seconds when peeking")

一种解决方案可能是使用lax.dynamic_sliceor lax.dynamic_update_slice,我相信其中有一个 Jax 实现jax.numpy.lax_numpy。但是,根据您的硬件,我的直觉是您不会发现太多的加速,因为 SVD 的 scipy 实现无论如何都是高度优化和编译的 Fortran 代码的包装器(对于单个 CPU 机器)。

于 2020-11-09T15:12:15.477 回答