I have trained an off-the-shelf Transformer().
Now I want to use its encoder to build a classifier. For that, I want to take only the output of the first token (BERT-style CLS-token result) and run it through a dense layer.
What I did:
tl.Serial(encoder, tl.Fn('pooler', lambda x: (x[:, 0, :])), tl.Dense(7))
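For context, here is a minimal self-contained sketch of what I am aiming for (with a stand-in tl.Embedding and made-up sizes in place of my pretrained encoder), which reproduces the shapes listed below:

import numpy as np
from trax import layers as tl
from trax import shapes

# Stand-in for the pretrained encoder: (batch, seq_len) int tokens -> (batch, seq_len, 512)
encoder = tl.Embedding(vocab_size=1000, d_feature=512)

model = tl.Serial(
    encoder,
    tl.Fn('pooler', lambda x: x[:, 0, :]),  # keep only the first token (CLS position)
    tl.Dense(7),
)

x = np.zeros((64, 50), dtype=np.int32)
model.init(shapes.signature(x))
print(model(x).shape)  # (64, 7)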
Shapes:
The encoder gives me shape (64, 50, 512), where
64 = batch_size,
50 = seq_len,
512 = model_dim.
The pooler gives me shape (64, 512), as intended and expected.
The dense layer should take the 512 dimensions of each batch member and classify them into 7 classes. But I guess trax/jax still expects something of length seq_len (50):
TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [50].
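From the error, it looks as if the Dense kernel's contracting dimension is 50 rather than 512, i.e. jnp.dot(x, w) gets x with last dimension 512 but w with first dimension 50. A hypothetical check (assuming the Serial above is stored in a variable model that the training Loop uses) to confirm what input size the Dense layer was initialized for:

dense = model.sublayers[-1]    # the Dense(7) layer
w, b = dense.weights
print(w.shape)                 # (512, 7) would be correct; (50, 7) would match the error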
What am I missing?
Full traceback:
Traceback (most recent call last):
File "mikado_classes.py", line 2054, in <module>
app.run(main)
File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 300, in run
_run_main(main, args)
File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
sys.exit(main(argv))
File "mikado_classes.py", line 1153, in main
loop_neu.run(2)
File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 361, in run
loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 483, in _run_one_step
batch, rng, step=step, learning_rate=learning_rate
File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 134, in one_step
(weights, self._slots), step, self._opt_params, batch, state, rng)
File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 173, in single_device_update_fn
batch, weights, state, rng)
File "/root/.local/lib/python3.7/site-packages/trax/layers/base.py", line 549, in pure_fn
self._caller, signature(x), trace) from None
jax._src.traceback_util.FilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/supervised/training.py, line 865
layer input shapes: (ShapeDtype{shape:(64, 50), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32})
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/mikado_classes.py, line 1134
layer input shapes: (ShapeDtype{shape:(64, 50), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32})
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Dense_7 (in pure_fn):
layer created in file [...]/mikado_classes.py, line 1133
layer input shapes: ShapeDtype{shape:(64, 512), dtype:float32}
File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)
File [...]/trax/layers/core.py, line 95, in forward
return jnp.dot(x, w) + b # Affine map.
File [...]/_src/numpy/lax_numpy.py, line 3498, in dot
return lax.dot_general(a, b, (contract_dims, batch_dims), precision)
File [...]/_src/lax/lax.py, line 674, in dot_general
preferred_element_type=preferred_element_type)
File [...]/site-packages/jax/core.py, line 282, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/ad.py, line 285, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
File [...]/jax/interpreters/ad.py, line 458, in standard_jvp
val_out = primitive.bind(*primals, **params)
File [...]/site-packages/jax/core.py, line 282, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 140, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 147, in default_process_primitive
return primitive.bind(*consts, **params)
File [...]/site-packages/jax/core.py, line 282, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 1058, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File [...]/_src/lax/lax.py, line 1992, in standard_abstract_eval
shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)
File [...]/_src/lax/lax.py, line 3090, in _dot_general_shape_rule
raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [50].
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "mikado_classes.py", line 2054, in <module>
app.run(main)
File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 300, in run
_run_main(main, args)
File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
sys.exit(main(argv))
File "mikado_classes.py", line 1153, in main
loop_neu.run(2)
File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 361, in run
loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 483, in _run_one_step
batch, rng, step=step, learning_rate=learning_rate
File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 134, in one_step
(weights, self._slots), step, self._opt_params, batch, state, rng)
File "/root/.local/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 398, in f_jitted
return cpp_jitted_f(context, *args, **kwargs)
File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 295, in cache_miss
donated_invars=donated_invars)
File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 1275, in bind
return call_bind(self, fun, *args, **params)
File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 1266, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 1278, in process
return trace.process_call(self, fun, tracers, params)
File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 631, in process_call
return primitive.impl(f, *tracers, **params)
File "/root/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 581, in _xla_call_impl
*unsafe_map(arg_spec, args))
File "/root/.local/lib/python3.7/site-packages/jax/linear_util.py", line 260, in memoized_fun
ans = call(fun, *args)
File "/root/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 656, in _xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
File "/root/.local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1216, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/root/.local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1196, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/root/.local/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 173, in single_device_update_fn
batch, weights, state, rng)
File "/root/.local/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 810, in value_and_grad_f
ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 1918, in _vjp
out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
File "/root/.local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 116, in vjp
out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
File "/root/.local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 101, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/root/.local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 506, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/root/.local/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/root/.local/lib/python3.7/site-packages/trax/layers/base.py", line 549, in pure_fn
self._caller, signature(x), trace) from None
trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/supervised/training.py, line 865
layer input shapes: (ShapeDtype{shape:(64, 50), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32})
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/mikado_classes.py, line 1134
layer input shapes: (ShapeDtype{shape:(64, 50), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32})
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Dense_7 (in pure_fn):
layer created in file [...]/mikado_classes.py, line 1133
layer input shapes: ShapeDtype{shape:(64, 512), dtype:float32}
File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)
File [...]/trax/layers/core.py, line 95, in forward
return jnp.dot(x, w) + b # Affine map.
File [...]/_src/numpy/lax_numpy.py, line 3498, in dot
return lax.dot_general(a, b, (contract_dims, batch_dims), precision)
File [...]/_src/lax/lax.py, line 674, in dot_general
preferred_element_type=preferred_element_type)
File [...]/site-packages/jax/core.py, line 282, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/ad.py, line 285, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
File [...]/jax/interpreters/ad.py, line 458, in standard_jvp
val_out = primitive.bind(*primals, **params)
File [...]/site-packages/jax/core.py, line 282, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 140, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 147, in default_process_primitive
return primitive.bind(*consts, **params)
File [...]/site-packages/jax/core.py, line 282, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 1058, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File [...]/_src/lax/lax.py, line 1992, in standard_abstract_eval
shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)
File [...]/_src/lax/lax.py, line 3090, in _dot_general_shape_rule
raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [50].