假设J是某个函数f关于某些参数的雅可比行列式。是否有有效的方法(在 PyTorch 或 Jax 中)拥有一个接受两个输入(x1和x2)并在J(x1)*J(x2).transpose() 不实例化内存中的整个J矩阵的情况下进行计算的函数?
我遇到过类似jvp(f, input, v=vjp(f, input))但不太理解的东西,也不确定是我想要的。
假设J是某个函数f关于某些参数的雅可比行列式。是否有有效的方法(在 PyTorch 或 Jax 中)拥有一个接受两个输入(x1和x2)并在J(x1)*J(x2).transpose() 不实例化内存中的整个J矩阵的情况下进行计算的函数?
我遇到过类似jvp(f, input, v=vjp(f, input))但不太理解的东西,也不确定是我想要的。