0

我正在为 colab 中的 Kaggle 应用于 word mnist 数据集的 OCR 模型工作。我受到来自 ocr 验证码的模型的启发,该模型具有由 A_K_Nain 在站点中托管的 Keras 示例中编写的 LSTM 和 CTC 层:https ://keras.io/examples/vision/captcha_ocr/ 我想保存模型但是当我尝试保存时加载它以对看不见的数据进行预测。我收到未知 CTClayer 的错误。ctclaer 不是在模型内部而是在模型外部定义的问题,所以当我尝试加载模型时,我会遇到错误。我找到了使用自定义模型的解决方案,但对我没有任何作用。如何保存托管在以下站点中的模型:https ://keras.io/examples/vision/captcha_ocr/

4

2 回答 2

0

CTC 层不用于进行预测,因此您可以在没有 CTC 层的情况下像这样保存:-

saving_model = keras.models.Model(model.get_layer(name="image").input, model.get_layer(name="dense2").output
)
saving_model.summary()
saving_model.save("model_tf")

除此之外,您必须进行一些更改才能使此代码在变量中工作:-

max_length = max([len(label) for label in labels])
outfile = open("max_length",'wb')
pickle.dump(max_length,outfile)
outfile.close()
import string
chars = string.printable
chars = chars[:-5]
characters = [c for c in chars]

这将给出一组定义的字符,这将有助于预测,因此在预测部分你必须做:-

infile = open("max_length",'rb')
max_length = pickle.load(infile)
infile.close()

import string
chars = string.printable
chars = chars[:-5]
characters = [c for c in chars]

# Mapping characters to integers
char_to_num = layers.experimental.preprocessing.StringLookup(
    vocabulary=characters, mask_token=None
)

# Mapping integers back to original characters
num_to_char = layers.experimental.preprocessing.StringLookup(
    vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)
prediction_model = tf.keras.models.load_model('model_tf')

然后进一步进行。

于 2021-08-13T08:31:03.877 回答
0

这是我们如何使用作者 A_K_Nain 代码预测新图像的方法。从相同的代码加载相关函数。

test_img_path =['/path/to/test/image/117011.png']

validation_dataset = tf.data.Dataset.from_tensor_slices((test_img_path[0:1], ['']))
validation_dataset = (
    validation_dataset.map(
        encode_single_sample, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    .batch(batch_size)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

for batch in validation_dataset.take(1):
    #print(batch['image'])
    
    preds = reconstructed_model.predict(batch['image']) # reconstructed_model is saved trained model
    pred_texts = decode_batch_predictions(preds)

print(pred_texts)
于 2021-12-07T14:38:48.170 回答