我正在使用变压器库来运行变压器模型(roberta-large-mnli
):
def model(n_classes):
input_ids = tf.keras.layers.Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_ids")
attention_mask = tf.keras.layers.Input(shape=(MAX_LEN,), dtype=tf.int32, name="attention_mask")
token_type_ids = tf.keras.layers.Input(shape=(MAX_LEN,), dtype=tf.int32, name="token_type_ids")
inputs_ = {'input_ids':input_ids, 'attention_mask': attention_mask, 'token_type_ids': token_type_ids}
net = TFRobertaModel.from_pretrained('roberta-large-mnli')(inputs_)['pooler_output']
classes_ = tf.keras.layers.Dense(n_classes, activation='softmax')(net)
model = tf.keras.Model(inputs=[input_ids, attention_mask, token_type_ids], outputs=classes_)
model.summary()
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
return model
这就是我开始培训过程的方式:
model.fit(x=training, steps_per_epoch=compute_steps(train['text']), epochs=EPOCHS,
validation_data=validation, validation_steps=compute_steps(dev['text']),
validation_freq=1, verbose=1, callbacks=callbacks())
training
是一个张量流数据集 ( tensorflow.python.data.ops.dataset_ops.PrefetchDataset
)。
当我运行此代码时,我收到以下错误:
__________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to
==================================================================================================
attention_mask (InputLayer) [(None, 256)] 0
__________________________________________________________________________________________________
input_ids (InputLayer) [(None, 256)] 0
__________________________________________________________________________________________________
token_type_ids (InputLayer) [(None, 256)] 0
__________________________________________________________________________________________________
tf_roberta_model (TFRobertaMode TFBaseModelOutputWit 355359744 attention_mask[0][0]
input_ids[0][0] 10:21 21-Jun-21
token_type_ids[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 3) 3075 tf_roberta_model[0][1]
==================================================================================================
Total params: 355,362,819
Trainable params: 355,362,819
Non-trainable params: 0
__________________________________________________________________________________________________
Texts contains 99 rows
Texts contains 95 rows
Epoch 1/100
Traceback (most recent call last):
File "n_shots.py", line 110, in <module>
validation_freq=1, verbose=1, callbacks=callbacks())
File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
return method(self, *args, **kwargs)
File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
tmp_logs = train_function(iterator)
File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
result = self._call(*args, **kwds)
File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 823, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 697, in _initialize
*args, **kwds))
File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_
internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py
_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
raise e.ag_error_metadata.to_exception(e)
AssertionError: in user code:
/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:806 train_function *
return step_function(self, iterator)
/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:796 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:1211 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:2585 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:2945 _call_for_each_replica
return fn(*args, **kwargs)
/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:789 run_step **
outputs = model.train_step(data)
/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:747 train_step
y_pred = self(x, training=True)
/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py:985 __call__
outputs = call_fn(inputs, *args, **kwargs)
/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/functional.py:386 call
inputs, training=training, mask=mask)
/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/functional.py:517 _run_internal_graph
assert x_id in tensor_dict, 'Could not compute output ' + str(x)
AssertionError: Could not compute output Tensor("dense/Softmax:0", shape=(None, 3), dtype=float32)
如果我换行:
net = TFRobertaModel.from_pretrained('roberta-large-mnli')(inputs_)['pooler_output']
至
net = TFBertModel.from_pretrained('bert-base-uncased').bert(inputs_)['pooler_output']
一切正常。