3
class CTCLoss(keras.losses.Loss):

 def __init__(self, logits_time_major=False, blank_index=-1, 
              reduction=keras.losses.Reduction.AUTO, name='ctc_loss'):
     super().__init__(reduction=reduction, name=name)
     self.logits_time_major = logits_time_major
     self.blank_index = blank_index

 def call(self, y_true, y_pred):
     y_true = tf.cast(y_true, tf.int32)
     y_true = tf.reshape(y_true,  [batch_size, max_label_seq_length])
     y_pred = tf.reshape(y_pred, [frames, batch_size, num_labels])
     loss = tf.nn.ctc_loss(
         labels=y_true,
         logits=y_pred,
         label_length=4480,
         logit_length=4480)
     return tf.reduce_mean(loss)

model = Sequential()
model.add(Bidirectional(LSTM(35, input_shape=X_train.shape, return_sequences=True)))
# didn't add the hidden layers in this code snippet. 
model.add(Flatten())
model.add(Dense((4480), activation='softmax'))

model.compile(optimizer='adam',
           loss=CTCLoss(),
           metrics=['accuracy'])

我正在尝试解决在线手写识别问题,并尝试使用 CTC 损失函数。我尝试在上面的代码中使用这个类作为我的 CTC 损失函数。但是关于被抛出的尺寸有一个错误。有人可以解释一下这些参数是什么吗?尤其是 [frames, batch_size, num_labels] 中的“帧”是什么意思。请让我知道我在这个特定代码中哪里出错了。我的 X_train 的形状为 (1311, 919, 3)。谢谢。

4

0 回答 0