1

例如,您设置了一个具有参数的模块。但是,如果您想在损失中规范化某些东西,那么模式是什么?

import jax.numpy as jnp
import jax
def loss(params, x, y):
   l = jnp.sum((y - mlp.apply(params, x)) ** 2)
   w = hk.get_params(params, 'w') # does not work like this
   l += jnp.sum(w ** w)
   return l

示例中缺少一些模式。

4

1 回答 1

1

params本质上是一个只读字典,所以你可以通过把它当作字典来获取参数的值:

print(params['w'])

如果要更新参数,则不能就地进行,而必须先将其转换为可变字典:

params_mutable = hk.data_structures.to_mutable_dict(params)
params_mutable['w'] = 3.14
params_new = hk.data_structures.to_immutable_dict(params_mutable)
于 2021-09-03T17:36:29.063 回答