假设J
是某个函数f
关于某些参数的雅可比行列式。是否有有效的方法(在 PyTorch 或 Jax 中)拥有一个接受两个输入(x1
和x2
)并在J(x1)*J(x2).transpose()
不实例化内存中的整个J
矩阵的情况下进行计算的函数?
我遇到过类似jvp(f, input, v=vjp(f, input))
但不太理解的东西,也不确定是我想要的。
假设J
是某个函数f
关于某些参数的雅可比行列式。是否有有效的方法(在 PyTorch 或 Jax 中)拥有一个接受两个输入(x1
和x2
)并在J(x1)*J(x2).transpose()
不实例化内存中的整个J
矩阵的情况下进行计算的函数?
我遇到过类似jvp(f, input, v=vjp(f, input))
但不太理解的东西,也不确定是我想要的。