4

当前的Keras Captcha OCR 模型返回一个 CTC 编码的输出,需要在推理后进行解码。

要对此进行解码,需要在推理之后作为单独的步骤运行解码实用程序函数。

preds = prediction_model.predict(batch_images)
pred_texts = decode_batch_predictions(preds)

解码后的效用函数使用keras.backend.ctc_decode,而后者又使用贪婪或波束搜索解码器。

# A utility function to decode the output of the network
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
        :, :max_length
    ]
    # Iterate over the results and get back the text
    output_text = []
    for res in results:
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        output_text.append(res)
    return output_text

我想使用 Keras 训练一个 Captcha OCR 模型,该模型返回解码后的 CTC 作为输出,而无需在推理后进行额外的解码步骤。

我将如何实现这一目标?

4

2 回答 2

1

实现这一点的最稳健的方法是添加一个作为模型定义的一部分调用的方法:

def CTCDecoder():
  def decoder(y_pred):
    input_shape = tf.keras.backend.shape(y_pred)
    input_length = tf.ones(shape=input_shape[0]) * tf.keras.backend.cast(
        input_shape[1], 'float32')
    unpadded = tf.keras.backend.ctc_decode(y_pred, input_length)[0][0]
    unpadded_shape = tf.keras.backend.shape(unpadded)
    padded = tf.pad(unpadded,
                    paddings=[[0, 0], [0, input_shape[1] - unpadded_shape[1]]],
                    constant_values=-1)
    return padded

return tf.keras.layers.Lambda(decoder, name='decode')

然后定义模型如下:

prediction_model = keras.models.Model(inputs=inputs, outputs=CTCDecoder()(model.output))

归功于tulasiram58827 。

此实现支持导出到 TFLite,但仅支持 float32。Quantized (int8) TFLite export 仍然抛出错误,并且是 TF 团队的公开票。

于 2021-04-22T20:04:49.890 回答
1

你的问题可以用两种方式解释。一个是:我想要一个神经网络来解决一个问题,即 CTC 解码步骤已经在网络学习的内容中。另一个是您希望有一个模型类在其中执行此 CTC 解码,而不使用外部功能函数。

我不知道第一个问题的答案。我什至无法判断它是否可行。无论如何,这听起来像是一个困难的理论问题,如果您在这里没有运气,您可能想尝试将其发布在datascience.stackexchange.com 上,这是一个更加面向理论的社区。

现在,如果您要解决的是问题的第二个工程版本,那么我可以为您提供帮助。该问题的解决方案如下:

您需要keras.models.Model使用您想要的方法对一个类进行子类化。我查看了您发布的链接中的教程,并附带了以下课程:

class ModifiedModel(keras.models.Model):
    
    # A utility function to decode the output of the network
    def decode_batch_predictions(self, pred):
        input_len = np.ones(pred.shape[0]) * pred.shape[1]
        # Use greedy search. For complex tasks, you can use beam search
        results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
            :, :max_length
        ]
        # Iterate over the results and get back the text
        output_text = []
        for res in results:
            res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
            output_text.append(res)
        return output_text

    
    def predict_texts(self, batch_images):
        preds = self.predict(batch_images)
        return self.decode_batch_predictions(preds)

你可以给它起你想要的名字,它只是为了说明的目的。定义此类后,您将替换该行

# Get the prediction model by extracting layers till the output layer
prediction_model = keras.models.Model(
    model.get_layer(name="image").input, model.get_layer(name="dense2").output
)

prediction_model = ModifiedModel(
    model.get_layer(name="image").input, model.get_layer(name="dense2").output
)

然后你可以替换行

preds = prediction_model.predict(batch_images)
pred_texts = decode_batch_predictions(preds)

pred_texts = prediction_model.predict_texts(batch_images)
于 2021-04-22T01:03:34.667 回答