1

假设J是某个函数f关于某些参数的雅可比行列式。是否有有效的方法(在 PyTorch 或 Jax 中)拥有一个接受两个输入(x1x2)并在J(x1)*J(x2).transpose() 实例化内存中的整个J矩阵的情况下进行计算的函数?

我遇到过类似jvp(f, input, v=vjp(f, input))但不太理解的东西,也不确定是我想要的。

4

1 回答 1

2
于 2020-09-17T21:34:32.350 回答