2

我正在构建一个 tfx 管道并使用 tensorflow 服务来服务我的模型。我用 . 保存签名model.save(...)

到目前为止,我能够在预测之前使用转换层来转换特征tf_transform_output.transform_features_layer()(参见下面的代码)。

但是,我想知道如何检测输入数据中的异常?例如,我不想预测与之前训练特征的分布相距太远的输入值。

tfdv库提供了类似的功能,generate_statistics_from_[csv|dataframe|tfrecord]但我找不到任何好的示例来生成序列化tf.Examples 的统计信息(或未保存在文件中的内容,如 csv、tfrecords 等)。

我知道文档中的以下示例

   import tensorflow_data_validation as tfdv
   import tfx_bsl
   import pyarrow as pa
   decoder = tfx_bsl.coders.example_coder.ExamplesToRecordBatchDecoder()
   example = decoder.DecodeBatch([serialized_tfexample])
   options = tfdv.StatsOptions(schema=schema)
   anomalies = tfdv.validate_instance(example, options)

但是在这个例子serialized_tfexample中是一个字符串,而在我下面的代码中,参数serialized_tf_examples是一个字符串的张量。

对不起,如果这是一个明显的问题。我花了一整天的时间寻找解决方案,但没有成功。也许我把这一切都弄错了。也许这不是放置验证的正确位置。所以我更笼统的问题实际上是:当您在生产中提供通过 tfx 管道创建的模型时,如何在预测之前验证传入的输入数据?我很感谢任何引导到正确方向的方法。

这是我要添加验证的代码:

...

tf_transform_output = tft.TFTransformOutput(...)
model.tft_layer = tf_transform_output.transform_features_layer()

@tf.function(input_signature=[
    tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')
])
def serve_tf_examples_fn(serialized_tf_examples):

    #### How can I generate stats and validate serialized_tf_examples? ###
    #### Is this the right place? ###

    feature_spec = tf_transform_output.raw_feature_spec()
    feature_spec.pop(TARGET_LABEL)
    parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)

    transformed_features = model.tft_layer(parsed_features)

    return model(transformed_features)

...

model.save(serving_model_dir,
           save_format='tf',
           signatures={
               'serving_default': serve_tf_examples_fn
           })
4

0 回答 0