Thanks to @jakevdp himself for prompting me to try some alternative Google queries: it turns out that, as of https://github.com/google/jax/pull/484, the grad function has an aux option. I don't think this is very obvious to TensorFlow 2 users migrating to JAX, given how explicitly you work with GradientTape there.
Something like the example below shows the auxiliary information being returned. It even appears to handle a dict, which is handy for periodic logging inside an update loop.
import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.PRNGKey(0)
theta = jax.random.normal(key, (10, 1))
y = np.random.randn(10, 1)
alpha = 0.01

def loss(theta, y):
    loss_reg = jnp.sum(theta ** 2)
    loss_data = jnp.sum((y - theta) ** 2)
    loss = loss_data + alpha * loss_reg
    # Return (scalar loss, auxiliary data); the aux value can be any pytree, e.g. a dict.
    return loss, dict(loss_reg=loss_reg, loss_data=loss_data)

# has_aux=True tells grad that loss returns (value, aux): it differentiates the
# first element and passes the second through untouched.
grad, aux = jax.grad(loss, has_aux=True)(theta, y)
display(grad)
display(aux)

# Without has_aux, grad rejects the tuple output because it is not a scalar.
try:
    jax.grad(loss)(theta, y)
except TypeError as e:
    print(f'yes got error {e}')
Output:
DeviceArray([[-1.4899637 ],
[-0.71481365],
[-0.6030376 ],
[-0.8263864 ],
[-1.8103108 ],
[ 0.69435316],
[-1.5611547 ],
[-1.6380725 ],
[ 0.9838154 ],
[ 0.21186407]], dtype=float32)
{'loss_data': DeviceArray(3.3714797, dtype=float32),
'loss_reg': DeviceArray(2.658556, dtype=float32)}
yes got error Gradient only defined for scalar-output functions. Output was (DeviceArray(3.3980653, dtype=float32), {'loss_data': DeviceArray(3.3714797, dtype=float32), 'loss_reg': DeviceArray(2.658556, dtype=float32)}).
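
For the update-loop logging use case, has_aux also works with jax.value_and_grad, so a single call returns the loss value, the gradient, and the aux dict to log. Below is a minimal sketch of that pattern; the learning rate, step count, and logging interval (lr, num_steps, log_every) are just illustrative choices, not anything from the original example.

import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.PRNGKey(0)
theta = jax.random.normal(key, (10, 1))
y = np.random.randn(10, 1)
alpha = 0.01
lr = 0.1          # illustrative step size
num_steps = 100   # illustrative number of gradient steps
log_every = 20    # illustrative logging interval

def loss(theta, y):
    loss_reg = jnp.sum(theta ** 2)
    loss_data = jnp.sum((y - theta) ** 2)
    return loss_data + alpha * loss_reg, dict(loss_reg=loss_reg, loss_data=loss_data)

# value_and_grad with has_aux=True returns ((value, aux), grad) in one call.
loss_and_grad = jax.jit(jax.value_and_grad(loss, has_aux=True))

for step in range(num_steps):
    (value, aux), g = loss_and_grad(theta, y)
    theta = theta - lr * g  # plain gradient-descent update
    if step % log_every == 0:
        # The aux dict rides along with the gradient, so periodic logging is one line.
        print(step, float(value), {k: float(v) for k, v in aux.items()})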