例如,jax.experimental.stax
有一个像这样实现的密集层:
def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
"""Layer constructor function for a dense (fully-connected) layer."""
def init_fun(rng, input_shape):
output_shape = input_shape[:-1] + (out_dim,)
k1, k2 = random.split(rng)
W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
return output_shape, (W, b)
def apply_fun(params, inputs, **kwargs):
W, b = params
return jnp.dot(inputs, W) + b
return init_fun, apply_fun
例如,如果我们将bias 实现为允许为None,或者params 的长度为1,那么就会对grad 的工作方式产生影响。
在这里应该瞄准的模式是什么?jax.jit
有一个static_argnums
我想可以与一些has_bias
参数一起使用的,但涉及到簿记,我相信某处必须有一些例子。