
I am using JAX to implement a simple neural network (NN), and I want to access and save the gradients from backpropagation after the NN has run, for further analysis. I can temporarily access and inspect the gradients with the Python debugger (as long as I don't use jit), but I would like to save all the gradients over the entire training run and analyze them once training has finished. I came up with a rather hacky solution for this using id_tap and a global variable (see the code below), but I'd like to know whether there is a better solution that does not violate JAX's functional principles.
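To make the debugger part concrete: all I do there is put a breakpoint() inside the backward rule and inspect g by hand, which stops working as soon as loss is jitted (this mirrors qmvm_bwd from the full code further down; qmvm_bwd_debug is just an illustrative name):

import jax.numpy as jnp

def qmvm_bwd_debug(ctx, g):
    # same backward rule as qmvm_bwd below, but pausing to inspect g interactively
    x, w = ctx
    breakpoint()  # only reachable when loss() is not wrapped in @jit
    fwd_grad = jnp.dot(g, w.transpose())   # gradient w.r.t. x
    w_grad = jnp.dot(x.transpose(), g)     # gradient w.r.t. w
    return fwd_grad, w_grad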

Thanks a lot!

import jax.numpy as jnp
from jax import grad, jit, vmap, random, custom_vjp
from jax.experimental.host_callback import id_tap

# experimental solution
global_save_list = {'x':[],'w':[],'g':[],'des':[]}
def global_save_func(ctx, des):
    # id_tap calls this on the host with the tapped values (ctx) and a
    # description of the transformations applied to them (des), e.g. the
    # batching introduced by vmap
    x, w, g = ctx
    global_save_list['x'].append(x)
    global_save_list['w'].append(w)
    global_save_list['g'].append(g)
    global_save_list['des'].append(des)


@custom_vjp
def qmvm(x, w):
    return jnp.dot(x, w)

def qmvm_fwd(x, w):
    return qmvm(x, w), (x, w)

def qmvm_bwd(ctx, g):
    x, w = ctx

    # here I would like to save the gradients g - or at least running statistics of them

    # experimental solution with id_tap
    id_tap(global_save_func, (x, w, g))

    fwd_grad = jnp.dot(g, w.transpose())   # gradient w.r.t. x
    w_grad = jnp.dot(x.transpose(), g)     # gradient w.r.t. w

    return fwd_grad, w_grad

qmvm.defvjp(qmvm_fwd, qmvm_bwd)

def run_nn(x, w):
    out = qmvm(x, w)   # 1st MVM
    out = qmvm(out, w) # 2nd MVM
    return out

run_nn_batched = vmap(run_nn)

@jit
def loss(x, w, target):
    out = run_nn_batched(x, w)
    return jnp.sum((out - target)**2)

key = random.PRNGKey(42)
subkey1, subkey2, subkey3 = random.split(key, 3)

A = random.uniform(subkey1, (10, 10, 10), minval = -10, maxval = 10)
B = random.uniform(subkey2, (10, 10, 10), minval = -10, maxval = 10)
C = random.uniform(subkey3, (10, 10, 10), minval = -10, maxval = 10)

for e in range(10):
    gval = grad(loss, argnums = 0)(A, B, C)
    # some type of update rule

# here I would like to access gradients, preferably knowing to which MVM (1st or 2nd) and example they belong

# experimental solution:
print(global_save_list) 
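And in case it helps to pin down what I mean by "not violating JAX's functional principles": for this tiny two-layer example I can imagine computing the per-layer cotangents explicitly with jax.vjp and returning them as ordinary outputs instead of tapping them out to a global list (rough, unbatched sketch; loss_and_layer_grads is just an illustrative name), but I don't see how to do this cleanly without spelling out the backward pass by hand:

import jax
import jax.numpy as jnp

def loss_and_layer_grads(x, w, target):
    # forward pass, keeping the vjp of the 2nd MVM so the cotangent can be walked back
    h1 = jnp.dot(x, w)                                  # 1st MVM
    h2, vjp2 = jax.vjp(lambda v: jnp.dot(v, w), h1)     # 2nd MVM
    loss_val = jnp.sum((h2 - target) ** 2)
    g2 = 2.0 * (h2 - target)   # cotangent arriving at the 2nd MVM's output
    (g1,) = vjp2(g2)           # cotangent arriving at the 1st MVM's output
    return loss_val, (g1, g2)

This would be jit- and vmap-compatible, and the gradients would come back as regular return values that I could stack over training steps; but it means writing the backward pass manually, which is exactly what I was hoping to avoid.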