我有一个模型,由 CNN、RNN 和输出层组成。我拥有的数据是图像和它的转录。转录被填充到 9 个字符的长度。对于 CTC 损失,我遵循了 keras ocr 示例代码,如下所示:
class CTCLayer(layers.Layer):
def __init__(self, name=None):
super().__init__(name=name)
self.loss_fn = keras.backend.ctc_batch_cost
def call(self, y_true, y_pred):
batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
loss = self.loss_fn(y_true, y_pred, input_length, label_length)
self.add_loss(loss)
return y_pred
现在这是我实施它的方式:
#l is the number possible of classes / characters
labels = layers.Input(shape=(9,), dtype="float32")
outputs = layers.Dense(l+1, activation='softmax',name='output')(lstm)
output = CTCLayer()(labels,outputs)
model = Model(inputs = [input_layer,labels],outputs=output)
model = model.compile(optimizer = optimizers.Adam(0.01))
model.fit([x_train,y_train],y_train,validation_split = 0.2, epochs = 100)
一旦运行 model.fit 开始发生一些奇怪的事情,我得到了一个 inf 训练损失,但一个大约 20 的验证损失。我查看了可能导致它的原因并遇到了这篇文章。接受的答案如下:
绝对是导致问题的输入的序列长度。显然,序列长度应该比地面实况长度大一点。
这是什么意思,我需要如何更改我的代码才能解决我遇到的问题?