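I'm using TensorFlow 2 with the LaBSE pretrained model from TF Hub (https://tfhub.dev/google/LaBSE/2), and I'm fairly new to both. I'm trying to train a multi-class classifier on a custom text dataset. I'm also following the BERT classifier tutorial (https://www.tensorflow.org/text/tutorials/classify_text_with_bert) to understand how the model is built. For now I only want to check that I can train and run the model. The input text comes from CSV data like the following, which I load into Dataset objects: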
"Sentence","label"
"sentence sample1", 0
"sentence sample2", 3
I split the data into X and y sets as usual. However, I get the error shown below when I try to train the model. Here is my code:
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

def build_classifier_model():
    # A string input feeds the TF Hub preprocessing layer, then the LaBSE encoder.
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
    preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
    encoder_inputs = preprocessing_layer(text_input)
    encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='LaBSE_encoder')
    outputs = encoder(encoder_inputs)
    net = outputs['pooled_output']
    activation = tf.keras.activations.softmax  # None
    net = tf.keras.layers.Dropout(0.1)(net)
    # Four output classes, with softmax probabilities.
    net = tf.keras.layers.Dense(4, activation=activation, name='classifier')(net)
    return tf.keras.Model(text_input, net)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
metrics = tf.keras.metrics.SparseCategoricalAccuracy()
epochs = 5
optimizer = tf.keras.optimizers.Adam()
train_dataset = tf.data.Dataset.from_tensor_slices((  # convert to dataset objects
    np.array(X_train), np.array(y_train, dtype='int32')
))
test_dataset = tf.data.Dataset.from_tensor_slices((
    np.array(X_test), np.array(y_test, dtype='int32')
))
The spec of these dataset objects is: <TensorSliceDataset shapes: ((), ()), types: (tf.string, tf.int32)>
classifier_model = build_classifier_model()  # build the model defined above
classifier_model.compile(optimizer=optimizer,
                         loss=loss,
                         metrics=metrics)
his = classifier_model.fit(train_dataset, validation_data=test_dataset,
                           epochs=epochs, batch_size=8)  # ignore that I'm using the test dataset for validation
The last step gives this error:
Epoch 1/5
WARNING:tensorflow:Model was constructed with shape (None,) for input KerasTensor(type_spec=TensorSpec(shape=(None,), dtype=tf.string, name='text'), name='text', description="created by layer 'text'"), but it was called on an input with incompatible shape ().
WARNING:tensorflow:Model was constructed with shape (None,) for input KerasTensor(type_spec=TensorSpec(shape=(None,), dtype=tf.string, name='text'), name='text', description="created by layer 'text'"), but it was called on an input with incompatible shape ().
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-141-3523a14b56f1> in <module>()
1 history = classifier_model.fit(train_dataset, validation_data=test_dataset,
----> 2 epochs=epochs, batch_size=8)
9 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
984 except Exception as e: # pylint:disable=broad-except
985 if hasattr(e, "ag_error_metadata"):
--> 986 raise e.ag_error_metadata.to_exception(e)
987 else:
988 raise
ValueError: in user code:
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:855 train_function *
return step_function(self, iterator)
/usr/local/lib/python3.7/dist-packages/tensorflow_hub/keras_layer.py:237 call *
result = smart_cond.smart_cond(training,
/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/load.py:670 _call_attribute **
return instance.__call__(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py:889 __call__
result = self._call(*args, **kwds)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py:924 _call
results = self._stateful_fn(*args, **kwds)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py:3022 __call__
filtered_flat_args) = self._maybe_define_function(args, kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py:3444 _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py:3289 _create_graph_function
capture_by_value=self._capture_by_value),
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py:999 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py:672 wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/function_deserialization.py:291 restored_function_body
"\n\n".join(signature_descriptions)))
ValueError: Could not find matching function to call loaded from the SavedModel. Got:
  Positional arguments (3 total):
    * Tensor("inputs:0", shape=(), dtype=string)
    * False
    * None
  Keyword arguments: {}
Expected these arguments to match one of the following 4 option(s):
Option 1:
  Positional arguments (3 total):
    * TensorSpec(shape=(None,), dtype=tf.string, name='inputs')
    * True
    * None
  Keyword arguments: {}
Option 2:
  Positional arguments (3 total):
    * TensorSpec(shape=(None,), dtype=tf.string, name='inputs')
    * False
    * None
  Keyword arguments: {}
Option 3:
  Positional arguments (3 total):
    * TensorSpec(shape=(None,), dtype=tf.string, name='sentences')
    * True
    * None
  Keyword arguments: {}
Option 4:
  Positional arguments (3 total):
    * TensorSpec(shape=(None,), dtype=tf.string, name='sentences')
    * False
    * None
  Keyword arguments: {}
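I think the problem is with the spec of the dataset objects I pass as input, but I don't understand the exact cause or how to fix it. In particular, I don't see why the input is incompatible with what the model expects when my dataset objects already have type tf.string. I have looked at existing answers, but since I'm not very familiar with TF, I'd like to know what causes this and how to resolve it.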
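The error reports `shape=()` where the SavedModel expects `shape=(None,)`, which suggests each dataset element is a single scalar string rather than a batch of strings. The snippet below is a minimal sketch, not a confirmed fix for the model above: it only compares the element spec of a dataset before and after `Dataset.batch`. The sample arrays and the batch size of 2 are made up for illustration.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for X_train / y_train.
texts = np.array(["sentence sample1", "sentence sample2", "sentence sample3"])
labels = np.array([0, 3, 1], dtype="int32")

# from_tensor_slices yields one (text, label) pair per element: scalar shapes.
unbatched = tf.data.Dataset.from_tensor_slices((texts, labels))

# batch() groups elements, adding a leading batch dimension of unknown size.
batched = unbatched.batch(2)  # batch size chosen arbitrarily for the sketch

text_spec, label_spec = unbatched.element_spec
batched_text_spec, batched_label_spec = batched.element_spec

print("unbatched text spec:", text_spec)
print("batched text spec:  ", batched_text_spec)
```

The Keras documentation for `Model.fit` says not to specify `batch_size` when the data is a dataset, because datasets are expected to generate batches themselves. If that applies here, batching `train_dataset` and `test_dataset` before calling `fit` may be what produces the `(None,)` shape the encoder expects.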