-1

例如,jax.experimental.stax有一个像这样实现的密集层:

def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
  """Layer constructor function for a dense (fully-connected) layer."""
  def init_fun(rng, input_shape):
    output_shape = input_shape[:-1] + (out_dim,)
    k1, k2 = random.split(rng)
    W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
    return output_shape, (W, b)
  def apply_fun(params, inputs, **kwargs):
    W, b = params
    return jnp.dot(inputs, W) + b
  return init_fun, apply_fun

例如,如果我们将bias 实现为允许为None,或者params 的长度为1,那么就会对grad 的工作方式产生影响。

在这里应该瞄准的模式是什么?jax.jit有一个static_argnums我想可以与一些has_bias参数一起使用的,但涉及到簿记,我相信某处必须有一些例子。

4

1 回答 1

0

这不行吗?

def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
  """Layer constructor function for a dense (fully-connected) layer."""

  def init_fun(rng, input_shape):
    output_shape = input_shape[:-1] + (out_dim,)
    k1, k2 = random.split(rng)
    W = W_init(k1, (input_shape[-1], out_dim))
    if b_init:
        b = b_init(k2, (out_dim,)
        return output_shape, (W, b)
    return output_shape, W

  def apply_fun(params, inputs, **kwargs):
    if len(params) == 1:
        W = params
        return jnp.dot(inputs, W)
    else:
        W, b = params
        return jnp.dot(inputs, W) + b

  return init_fun, apply_fun
于 2021-09-14T09:22:34.047 回答