运行环境:
TensorFlow:2.4.1
Python:3.6.9
CUDA:11
cuDNN:8
我使用 tf.data.Dataset 生成数据。批量大小为 1 时代码运行正常,但增大批量大小后就会出现下面的错误。由于这是一个语音识别系统,我使用了自定义的 CTC 损失函数(代码如下)。请问该如何解决这个错误?
class CTCLayer(layers.Layer):
    """Keras layer that attaches a CTC loss to the model and passes predictions through.

    The loss is computed with ``keras.backend.ctc_batch_cost`` and registered
    on the layer via ``self.add_loss``. The layer's output is ``y_pred``
    unchanged, so it can sit at the end of a training model.

    Args:
        name: Optional layer name, forwarded to ``layers.Layer``.
        label_padding_value: Optional value that was used to pad ``y_true``
            (for example the ``padding_values`` given to
            ``tf.data.Dataset.padded_batch`` when batching variable-length
            transcripts). When set, each sample's label length is the number
            of entries in its row of ``y_true`` that differ from this value.
            When ``None`` (the default), every sample's label length is the
            full width ``tf.shape(y_true)[1]``, which is the original
            behaviour. The padding value must not also be a valid label id.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, name=None, *, label_padding_value=None, **kwargs):
        super(CTCLayer, self).__init__(name=name, **kwargs)
        self.loss_fn = keras.backend.ctc_batch_cost
        self.label_padding_value = label_padding_value

    def call(self, y_true, y_pred):
        """Compute the batch CTC loss, register it with ``add_loss``, and return ``y_pred``.

        Args:
            y_true: Label tensor. Presumably (batch, max_label_len) of integer
                label ids; TODO confirm against the model's label input.
            y_pred: Model output. Dimension 1 is used as the input (time)
                length. Presumably (batch, time, num_classes) probabilities;
                confirm.

        Returns:
            ``y_pred``, unchanged.
        """
        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")

        # Every sample is treated as using all time steps of y_pred. If the
        # audio features were padded per batch, padded frames are fed to CTC too.
        time_steps = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        input_length = time_steps * tf.ones(shape=(batch_len, 1), dtype="int64")

        if self.label_padding_value is None:
            # Original behaviour: every label row is assumed to be fully used.
            label_width = tf.cast(tf.shape(y_true)[1], dtype="int64")
            label_length = label_width * tf.ones(shape=(batch_len, 1), dtype="int64")
        else:
            # With padded_batch, rows are padded to the longest transcript in
            # the batch. Count only the non-padding entries per row so that
            # padding is not treated as part of each sample's label.
            pad = tf.cast(self.label_padding_value, y_true.dtype)
            is_label = tf.not_equal(y_true, pad)
            label_length = tf.reduce_sum(
                tf.cast(is_label, "int64"), axis=1, keepdims=True
            )

        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)

        # The loss is attached above; the predictions pass through unchanged.
        return y_pred
这是完整的错误:
Epoch 1/50
Traceback (most recent call last):
File "ctc-asr-v2/keras_model_train.py", line 348, in <module>
batch_size=2,
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py", line 1100, in fit
tmp_logs = self.train_function(iterator)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
result = self._call(*args, **kwds)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 888, in _call
return self._stateless_fn(*args, **kwds)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 2943, in __call__
filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1919, in _call_flat
ctx, args, cancellation_manager=cancellation_manager))
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 560, in call
ctx=ctx)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: Cannot add tensor to the batch: number of elements does not match. Shapes are: [tensor]: [125], [batch]: [149]
[[node IteratorGetNext (defined at ctc-asr-v2/keras_model_train.py:348) ]]
(1) Invalid argument: Cannot add tensor to the batch: number of elements does not match. Shapes are: [tensor]: [125], [batch]: [149]
[[node IteratorGetNext (defined at ctc-asr-v2/keras_model_train.py:348) ]]
[[model/ctc_loss/CTCLoss/_256]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_9169]
Function call stack:
train_function -> train_function
参考的示例代码:https://www.tensorflow.org/tutorials/audio/simple_audio