I'm using JAX to implement a simple neural network (NN), and I want to access and save the gradients from the backward pass after the NN has run, for further analysis. I can temporarily access and inspect the gradients with the Python debugger (as long as I don't use jit). But I would like to save all gradients over the whole training run and analyse them after training has finished. I've come up with a rather hacky solution for this using id_tap and a global variable (see the code below), but I'm wondering whether there is a better solution that doesn't violate JAX's functional principles.
Thanks a lot!
import jax.numpy as jnp
from jax import grad, jit, vmap, random, custom_vjp
from jax.experimental.host_callback import id_tap
# experimental solution
global_save_list = {'x':[],'w':[],'g':[],'des':[]}
def global_save_func(ctx, des):
    # ctx is the (x, w, g) tuple tapped below; des is the transforms metadata
    # that id_tap passes as the tap function's second argument
    x, w, g = ctx
    global_save_list['x'].append(x)
    global_save_list['w'].append(w)
    global_save_list['g'].append(g)
    global_save_list['des'].append(des)
@custom_vjp
def qmvm(x, w):
    return jnp.dot(x, w)
def qmvm_fwd(x, w):
    return qmvm(x, w), (x, w)
def qmvm_bwd(ctx, g):
    x, w = ctx
    # here I would like to save the gradients g - or at least running statistics
    # of them (see the sketch below the code for what I mean)
    # experimental solution with id_tap
    id_tap(global_save_func, (x, w, g))
    fwd_grad = jnp.dot(g, w.transpose())
    w_grad = jnp.dot(x.transpose(), g)  # VJP of x @ w w.r.t. w is x.T @ g
    return fwd_grad, w_grad
qmvm.defvjp(qmvm_fwd, qmvm_bwd)
def run_nn(x, w):
    out = qmvm(x, w)    # 1st MVM
    out = qmvm(out, w)  # 2nd MVM
    return out
run_nn_batched = vmap(run_nn)
@jit
def loss(x, w, target):
    out = run_nn_batched(x, w)
    return jnp.sum((out - target)**2)
key = random.PRNGKey(42)
subkey1, subkey2, subkey3 = random.split(key, 3)
A = random.uniform(subkey1, (10, 10, 10), minval = -10, maxval = 10)
B = random.uniform(subkey2, (10, 10, 10), minval = -10, maxval = 10)
C = random.uniform(subkey3, (10, 10, 10), minval = -10, maxval = 10)
for e in range(10):
    gval = grad(loss, argnums=0)(A, B, C)
    # some type of update rule
    # here I would like to access the gradients, preferably knowing to which MVM
    # (1st or 2nd) and example they belong
    # experimental solution:
    print(global_save_list)
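For reference, this is roughly the post-training analysis I have in mind for whatever gets saved; it just stacks the host-side copies with NumPy. The exact shapes depend on how often the tap fires (here: two qmvm backward calls per grad evaluation, each entry carrying the whole batch because of vmap):

import numpy as np

# each entry of global_save_list['g'] is the gradient tapped during one qmvm
# backward call; stack them into a single array for analysis
saved_g = np.stack([np.asarray(g) for g in global_save_list['g']])
print(saved_g.shape)
print('mean |g| per call:',
      np.abs(saved_g).reshape(saved_g.shape[0], -1).mean(axis=1))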
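And for concreteness, here is a minimal sketch of what I mean by "running statistics" and by knowing which MVM a gradient belongs to. It is still the same global-variable hack; tagged_qmvm, save_stats and global_stats are names I made up for illustration. Each call site is tagged with a static name via nondiff_argnums, and only summary values (not the full tensors) are sent to the host:

from functools import partial

global_stats = {}  # tag -> list of (mean |g|, max |g|) summaries, one per backward call

def save_stats(stats, transforms, *, name):
    # runs on the host; under vmap the summaries arrive batched (one value per example)
    mean_abs_g, max_abs_g = stats
    global_stats.setdefault(name, []).append((mean_abs_g, max_abs_g))

@partial(custom_vjp, nondiff_argnums=(2,))
def tagged_qmvm(x, w, name):
    return jnp.dot(x, w)

def tagged_qmvm_fwd(x, w, name):
    return tagged_qmvm(x, w, name), (x, w)

def tagged_qmvm_bwd(name, ctx, g):
    x, w = ctx
    # tap only summaries of g, tagged with the call-site name
    id_tap(partial(save_stats, name=name),
           (jnp.mean(jnp.abs(g)), jnp.max(jnp.abs(g))))
    return jnp.dot(g, w.transpose()), jnp.dot(x.transpose(), g)

tagged_qmvm.defvjp(tagged_qmvm_fwd, tagged_qmvm_bwd)

def tagged_run_nn(x, w):
    out = tagged_qmvm(x, w, 'mvm1')    # 1st MVM
    out = tagged_qmvm(out, w, 'mvm2')  # 2nd MVM
    return out

After training, global_stats['mvm1'] and global_stats['mvm2'] would each hold one summary per grad evaluation, which is roughly the per-layer record I'm after - but it still relies on a global variable, hence the question.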