例如,您设置了一个具有参数的模块。但是,如果您想在损失中规范化某些东西,那么模式是什么?
import jax.numpy as jnp
import jax
def loss(params, x, y):
l = jnp.sum((y - mlp.apply(params, x)) ** 2)
w = hk.get_params(params, 'w') # does not work like this
l += jnp.sum(w ** w)
return l
示例中缺少一些模式。