0

我正在使用 transformers 库来运行一个 Transformer 模型(roberta-large-mnli):

def model(n_classes):
    """Build and compile a Keras classifier on top of `roberta-large-mnli`.

    Three int32 inputs of length ``MAX_LEN`` are fed as a dict to a pretrained
    ``TFRobertaModel``; its ``pooler_output`` goes through a softmax ``Dense``
    layer with ``n_classes`` units.

    ``MAX_LEN`` and ``lr`` are not defined in this function and are read from
    the enclosing (module) scope.

    Args:
        n_classes: Number of units in the final softmax layer.

    Returns:
        The compiled ``tf.keras.Model``.
    """
    # One symbolic input per tokenizer field; the layer names double as the
    # keys of the dict passed to the transformer below.
    input_ids = tf.keras.layers.Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_ids")
    attention_mask = tf.keras.layers.Input(shape=(MAX_LEN,), dtype=tf.int32, name="attention_mask")
    # NOTE(review): RoBERTa tokenizers presumably do not emit token_type_ids by
    # default; if the dataset has no such key this Input cannot be fed --
    # verify against the dataset-building code.
    token_type_ids = tf.keras.layers.Input(shape=(MAX_LEN,), dtype=tf.int32, name="token_type_ids")
    inputs_ = {'input_ids':input_ids, 'attention_mask': attention_mask, 'token_type_ids': token_type_ids}

    # Load the pretrained weights and keep only the 'pooler_output' entry of
    # the model output.
    net = TFRobertaModel.from_pretrained('roberta-large-mnli')(inputs_)['pooler_output']
    
    # Classification head: one softmax unit per class.
    classes_ = tf.keras.layers.Dense(n_classes, activation='softmax')(net)
    # The functional model declares its inputs as a list (in this order), while
    # the transformer above received them as a dict.
    # Note: the local name `model` shadows this function's own name inside the body.
    model = tf.keras.Model(inputs=[input_ids, attention_mask, token_type_ids], outputs=classes_)
    # Prints the layer table to stdout as a side effect of building the model.
    model.summary()

    # NOTE(review): 'categorical_crossentropy' presumably requires one-hot
    # encoded labels -- confirm the dataset yields them in that form.
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    return model

我是这样启动训练过程的:

# Train for EPOCHS epochs, running validation after every epoch
# (validation_freq=1). Step counts come from compute_steps(...), and the
# callback list from callbacks(); both are defined outside this snippet.
model.fit(x=training, steps_per_epoch=compute_steps(train['text']), epochs=EPOCHS,
                  validation_data=validation, validation_steps=compute_steps(dev['text']),
                  validation_freq=1, verbose=1, callbacks=callbacks())

training 是一个 TensorFlow 数据集(tensorflow.python.data.ops.dataset_ops.PrefetchDataset)。

当我运行此代码时,我收到以下错误:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
attention_mask (InputLayer)     [(None, 256)]        0                                                                                      
__________________________________________________________________________________________________
input_ids (InputLayer)          [(None, 256)]        0                                                                                     
__________________________________________________________________________________________________
token_type_ids (InputLayer)     [(None, 256)]        0
__________________________________________________________________________________________________
tf_roberta_model (TFRobertaMode TFBaseModelOutputWit 355359744   attention_mask[0][0]                                                  
                                                                 input_ids[0][0]
                                                                 token_type_ids[0][0]                                                  
__________________________________________________________________________________________________
dense (Dense)                   (None, 3)            3075        tf_roberta_model[0][1]                                       
==================================================================================================
Total params: 355,362,819                                                                                                                       
Trainable params: 355,362,819                               
Non-trainable params: 0                                                                                                                          
__________________________________________________________________________________________________
Texts contains 99 rows                                                                                                            
Texts contains 95 rows                  
Epoch 1/100                                                                                                                     
Traceback (most recent call last):     
  File "n_shots.py", line 110, in <module>                                                                                      
    validation_freq=1, verbose=1, callbacks=callbacks())
  File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)     
  File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit  
    tmp_logs = train_function(iterator)                                 
  File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)                                                                
  File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 823, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)                                                  
  File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 697, in _initialize
    *args, **kwds))
  File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
    raise e.ag_error_metadata.to_exception(e)
AssertionError: in user code:

    /home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:806 train_function  *
        return step_function(self, iterator)
    /home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:796 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:1211 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:2585 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:2945 _call_for_each_replica
        return fn(*args, **kwargs)
    /home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:789 run_step  **
        outputs = model.train_step(data)
    /home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:747 train_step
        y_pred = self(x, training=True)
    /home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py:985 __call__
        outputs = call_fn(inputs, *args, **kwargs)
    /home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/functional.py:386 call
        inputs, training=training, mask=mask)
    /home/USER_NAME/venv/tf_23/lib/python3.6/site-packages/tensorflow/python/keras/engine/functional.py:517 _run_internal_graph
        assert x_id in tensor_dict, 'Could not compute output ' + str(x)

    AssertionError: Could not compute output Tensor("dense/Softmax:0", shape=(None, 3), dtype=float32)  

如果我把下面的第一行替换为第二行:

net = TFRobertaModel.from_pretrained('roberta-large-mnli')(inputs_)['pooler_output']

net = TFBertModel.from_pretrained('bert-base-uncased').bert(inputs_)['pooler_output']

一切正常。

4

0 回答 0