我正在构建一个 tfx 管道并使用 tensorflow 服务来服务我的模型。我用 . 保存签名model.save(...)
。
到目前为止,我能够在预测之前使用转换层来转换特征tf_transform_output.transform_features_layer()
(参见下面的代码)。
但是,我想知道如何检测输入数据中的异常?例如,我不想预测与之前训练特征的分布相距太远的输入值。
该tfdv
库提供了类似的功能,generate_statistics_from_[csv|dataframe|tfrecord]
但我找不到任何好的示例来生成序列化tf.Example
s 的统计信息(或未保存在文件中的内容,如 csv、tfrecords 等)。
我知道文档中的以下示例:
import tensorflow_data_validation as tfdv
import tfx_bsl
import pyarrow as pa
decoder = tfx_bsl.coders.example_coder.ExamplesToRecordBatchDecoder()
example = decoder.DecodeBatch([serialized_tfexample])
options = tfdv.StatsOptions(schema=schema)
anomalies = tfdv.validate_instance(example, options)
但是在这个例子serialized_tfexample
中是一个字符串,而在我下面的代码中,参数serialized_tf_examples
是一个字符串的张量。
对不起,如果这是一个明显的问题。我花了一整天的时间寻找解决方案,但没有成功。也许我把这一切都弄错了。也许这不是放置验证的正确位置。所以我更笼统的问题实际上是:当您在生产中提供通过 tfx 管道创建的模型时,如何在预测之前验证传入的输入数据?我很感谢任何引导到正确方向的方法。
这是我要添加验证的代码:
...
tf_transform_output = tft.TFTransformOutput(...)
model.tft_layer = tf_transform_output.transform_features_layer()
@tf.function(input_signature=[
tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')
])
def serve_tf_examples_fn(serialized_tf_examples):
#### How can I generate stats and validate serialized_tf_examples? ###
#### Is this the right place? ###
feature_spec = tf_transform_output.raw_feature_spec()
feature_spec.pop(TARGET_LABEL)
parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
transformed_features = model.tft_layer(parsed_features)
return model(transformed_features)
...
model.save(serving_model_dir,
save_format='tf',
signatures={
'serving_default': serve_tf_examples_fn
})