
I have trained an off-the-shelf Transformer().

Now I want to use the encoder to build a classifier. For this I only want to use the output of the first token (BERT-style CLS-token result) and run it through a dense layer.

What I did:

tl.Serial(encoder, tl.Fn('pooler', lambda x: (x[:, 0, :])), tl.Dense(7))

Shapes: the encoder gives me shape (64, 50, 512), where 64 = batch_size, 50 = seq_len, 512 = model_dim.

The pooler gives me shape (64, 512), exactly as expected.
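The pooling step can be checked in isolation with plain NumPy (a minimal sketch using the shapes from the post, with a dummy zero array standing in for the real encoder output):

```python
import numpy as np

# Simulated encoder output: (batch_size, seq_len, model_dim) = (64, 50, 512)
x = np.zeros((64, 50, 512))

# CLS-style pooling: keep only the first token of each sequence
pooled = x[:, 0, :]

print(pooled.shape)  # (64, 512)
```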

The dense layer should take the 512 dimensions for each batch member and classify over 7 classes. But I guess trax/jax still expects the input to have length seq_len (50):

TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [50].

What am I missing?

Full traceback:

Traceback (most recent call last):
  File "mikado_classes.py", line 2054, in <module>
    app.run(main)
  File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "mikado_classes.py", line 1153, in main
    loop_neu.run(2)
  File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 361, in run
    loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
  File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 483, in _run_one_step
    batch, rng, step=step, learning_rate=learning_rate
  File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 134, in one_step
    (weights, self._slots), step, self._opt_params, batch, state, rng)
  File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 173, in single_device_update_fn
    batch, weights, state, rng)
  File "/root/.local/lib/python3.7/site-packages/trax/layers/base.py", line 549, in pure_fn
    self._caller, signature(x), trace) from None
jax._src.traceback_util.FilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/supervised/training.py, line 865
  layer input shapes: (ShapeDtype{shape:(64, 50), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/mikado_classes.py, line 1134
  layer input shapes: (ShapeDtype{shape:(64, 50), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Dense_7 (in pure_fn):
  layer created in file [...]/mikado_classes.py, line 1133
  layer input shapes: ShapeDtype{shape:(64, 512), dtype:float32}

  File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
    y = forward(self, x, *args, **kwargs)

  File [...]/trax/layers/core.py, line 95, in forward
    return jnp.dot(x, w) + b  # Affine map.

  File [...]/_src/numpy/lax_numpy.py, line 3498, in dot
    return lax.dot_general(a, b, (contract_dims, batch_dims), precision)

  File [...]/_src/lax/lax.py, line 674, in dot_general
    preferred_element_type=preferred_element_type)

  File [...]/site-packages/jax/core.py, line 282, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/ad.py, line 285, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)

  File [...]/jax/interpreters/ad.py, line 458, in standard_jvp
    val_out = primitive.bind(*primals, **params)

  File [...]/site-packages/jax/core.py, line 282, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 140, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 147, in default_process_primitive
    return primitive.bind(*consts, **params)

  File [...]/site-packages/jax/core.py, line 282, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 1058, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 1992, in standard_abstract_eval
    shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)

  File [...]/_src/lax/lax.py, line 3090, in _dot_general_shape_rule
    raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))

TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [50].

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "mikado_classes.py", line 2054, in <module>
    app.run(main)
  File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "mikado_classes.py", line 1153, in main
    loop_neu.run(2)
  File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 361, in run
    loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
  File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 483, in _run_one_step
    batch, rng, step=step, learning_rate=learning_rate
  File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 134, in one_step
    (weights, self._slots), step, self._opt_params, batch, state, rng)
  File "/root/.local/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 398, in f_jitted
    return cpp_jitted_f(context, *args, **kwargs)
  File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 295, in cache_miss
    donated_invars=donated_invars)
  File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 1275, in bind
    return call_bind(self, fun, *args, **params)
  File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 1266, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 1278, in process
    return trace.process_call(self, fun, tracers, params)
  File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 631, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 581, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "/root/.local/lib/python3.7/site-packages/jax/linear_util.py", line 260, in memoized_fun
    ans = call(fun, *args)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 656, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1216, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1196, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/root/.local/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 173, in single_device_update_fn
    batch, weights, state, rng)
  File "/root/.local/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 810, in value_and_grad_f
    ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
  File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 1918, in _vjp
    out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 116, in vjp
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 101, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/root/.local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 506, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/root/.local/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/root/.local/lib/python3.7/site-packages/trax/layers/base.py", line 549, in pure_fn
    self._caller, signature(x), trace) from None
trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/trax/supervised/training.py, line 865
  layer input shapes: (ShapeDtype{shape:(64, 50), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/mikado_classes.py, line 1134
  layer input shapes: (ShapeDtype{shape:(64, 50), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32})

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Dense_7 (in pure_fn):
  layer created in file [...]/mikado_classes.py, line 1133
  layer input shapes: ShapeDtype{shape:(64, 512), dtype:float32}

  File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
    y = forward(self, x, *args, **kwargs)

  File [...]/trax/layers/core.py, line 95, in forward
    return jnp.dot(x, w) + b  # Affine map.

  File [...]/_src/numpy/lax_numpy.py, line 3498, in dot
    return lax.dot_general(a, b, (contract_dims, batch_dims), precision)

  File [...]/_src/lax/lax.py, line 674, in dot_general
    preferred_element_type=preferred_element_type)

  File [...]/site-packages/jax/core.py, line 282, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/ad.py, line 285, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)

  File [...]/jax/interpreters/ad.py, line 458, in standard_jvp
    val_out = primitive.bind(*primals, **params)

  File [...]/site-packages/jax/core.py, line 282, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 140, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 147, in default_process_primitive
    return primitive.bind(*consts, **params)

  File [...]/site-packages/jax/core.py, line 282, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/partial_eval.py, line 1058, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)

  File [...]/_src/lax/lax.py, line 1992, in standard_abstract_eval
    shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)

  File [...]/_src/lax/lax.py, line 3090, in _dot_general_shape_rule
    raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))

TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [50].


1 Answer


The error is not in the architecture. The problem is that my input shapes were wrong.

The targets should have shape (batch_size, ), but I was sending (batch_size, 1). So the target array should be, for example:

[1, 5, 99, 2, 1, 3, 2, 8]

but I had built

[[1], [5], [99], [2], [1], [3], [2], [8]].

Answered 2021-03-15T17:28:56.707
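A minimal NumPy sketch of the fix, using the example values above (`squeeze` is just one way to drop the extra trailing axis; reshaping the targets at data-pipeline time works equally well):

```python
import numpy as np

# Targets accidentally produced with an extra trailing axis: shape (8, 1)
targets = np.array([[1], [5], [99], [2], [1], [3], [2], [8]])
print(targets.shape)  # (8, 1)

# Drop the trailing axis so each example has a scalar class id: shape (8,)
targets = targets.squeeze(axis=-1)
print(targets.shape)  # (8,)
```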