0

我在 AWS SageMaker 上部署了一个 TensorFlow 模型,我希望能够使用 csv 文件作为调用主体来调用它。该文档说有关创建serving_input_function如下所示的内容:

def serving_input_fn(hyperparameters):
  # Logic to the following:
  # 1. Defines placeholders that TensorFlow serving will feed with inference requests
  # 2. Preprocess input data
  # 3. Returns a tf.estimator.export.ServingInputReceiver or tf.estimator.export.TensorServingInputReceiver,
  # which packages the placeholders and the resulting feature Tensors together.

在第 2 步中,它说预处理输入数据,我如何获取输入数据的句柄来处理它们?

4

2 回答 2

2

我有同样的问题,但我想处理 jpeg 请求。

准备好后,您model_data可以使用以下几行来部署它。

from sagemaker.tensorflow.model import TensorFlowModel
sagemaker_model = TensorFlowModel(
            model_data = 's3://path/to/model/model.tar.gz',
            role = role,
            framework_version = '1.12',
            entry_point = 'train.py',
            source_dir='my_src',
            env={'SAGEMAKER_REQUIREMENTS': 'requirements.txt'}
)

predictor = sagemaker_model.deploy(
    initial_instance_count=1,
    instance_type='ml.m4.xlarge', 
    endpoint_name='resnet-tensorflow-classifier'
)

你的笔记本应该有一个my_src目录,其中包含一个文件train.py和一个requirements.txt文件。该train.py文件应该input_fn定义一个函数。对我来说,该函数处理图像/jpeg 内容,但模式是相同的。

import io
import numpy as np
from PIL import Image
from keras.applications.resnet50 import preprocess_input
from keras.preprocessing import image

JPEG_CONTENT_TYPE = 'image/jpeg'
CSV_CONTENT_TYPE = 'text/csv'

# Deserialize the Invoke request body into an object we can perform prediction on
def input_fn(request_body, content_type=JPEG_CONTENT_TYPE):
    # process an image uploaded to the endpoint
    if content_type == JPEG_CONTENT_TYPE:
        img = Image.open(io.BytesIO(request_body)).resize((300, 300))
        img_array = np.array(img)
        expanded_img_array = np.expand_dims(img_array, axis=0)
        x = preprocess_input(expanded_img_array)
        return x

    # you would have something like this:
    if content_type == CSV_CONTENT_TYPE:
        # handle input 
        return handled_input

    else: 
        raise errors.UnsupportedFormatError(content_type)

如果您的train.py代码导入了一些模块,则必须提供requirements.txt定义这些依赖项(这是我在文档中找不到的部分)。

希望这对将来的某人有所帮助。

于 2019-05-24T22:35:21.627 回答
0

您可以通过添加一个 input_fn 来预处理输入数据,该 input_fn 将在您每次调用和端点时被调用。它接收输入数据和数据的内容类型。

def input_fn(data, content_type):
    // do some data preprocessing.
    return preprocessed_data

本文更深入地解释它: https ://docs.aws.amazon.com/sagemaker/latest/dg/tf-training-inference-code-template.html

于 2018-04-13T01:50:27.780 回答