使用 jax,我尝试计算每个样本的梯度,处理它们,然后将它们带入正常形式以计算正常参数更新。我的工作代码看起来像
differentiate_per_sample = jit(vmap(grad(loss), in_axes=(None, 0, 0)))
gradients = differentiate_per_sample(params, x, y)
# some code
gradients_summed_over_samples = []
for layer in gradients:
(dw, db) = layer
(dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
gradients_summed_over_samples.append((dw, db))
gradients
的形式在哪里list(tuple(DeviceArray(...), DeviceArray(...)), ...)
。
现在我尝试将循环重写为 vmap (不确定它是否最终会带来加速)
def sum_samples(layer):
(dw, db) = layer
(dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
vmap(sum_samples)(gradients)
但sum_samples
只调用一次,而不是为列表中的每个元素调用。
列表是问题还是我理解其他错误?