TL;DR Google Cloud AI PlatformTFRecord
在进行批量预测时如何解包文件?
我已将经过训练的 Keras 模型部署到 Google Cloud AI Platform,但我在批量预测的文件格式方面遇到了问题。对于培训,我正在使用 atf.data.TFRecordDataset
来阅读以下列表,TFRecord
这些列表都可以正常工作。
def unpack_tfrecord(record):
parsed = tf.io.parse_example(record, {
'chunk': tf.io.FixedLenFeature([128, 2, 3], tf.float32), # Input
'class': tf.io.FixedLenFeature([2], tf.int64), # One-hot classification (binary)
})
return (parsed['chunk'], parsed['class'])
files = [str(p) for p in training_chunks_path.glob('*.tfrecord')]
dataset = tf.data.TFRecordDataset(files).batch(32).map(unpack_tfrecord)
model.fit(x=dataset, epochs=train_epochs)
tf.saved_model.save(model, model_save_path)
我将保存的模型上传到 Cloud Storage 并在 AI Platform 中创建一个新模型。AI Platform 文档指出“使用 gcloud 工具进行批处理 [支持] 带有 JSON 实例字符串或 TFRecord 文件的文本文件(可能已压缩)”(https://cloud.google.com/ai-platform/prediction/docs/overview#prediction_input_data)。但是当我提供一个 TFRecord 文件时,我得到了错误:
("'utf-8' codec can't decode byte 0xa4 in position 1: invalid start byte", 8)
我的 TFRecord 文件包含一堆 Protobuf 编码的tf.train.Example
. 我没有向unpack_tfrecord
AI Platform 提供该功能,所以我想它无法正确解包是有道理的,但我知道从这里去哪里。由于数据太大,我对使用 JSON 格式不感兴趣。