I have successfully exported a seq2seq model in SavedModel format using the following code:
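import tensorflow as tf
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.contrib.learn.python.learn.utils.input_fn_utils import build_default_serving_input_fn

# Serving placeholders: one tokenized source sentence (batch size 1) and its length
source_tokens_ph = tf.placeholder(dtype=tf.string, shape=(1, None))
source_len_ph = tf.placeholder(dtype=tf.int32, shape=(1,))
features_serve = {
    "source_tokens": source_tokens_ph,
    "source_len": source_len_ph,
}
experiment = PatchedExperiment(
    ...
    export_strategies=[
        saved_model_export_utils.make_export_strategy(
            serving_input_fn=build_default_serving_input_fn(features_serve))
    ],
)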
saved_model_cli shows the following SignatureDefs in the exported model:
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['default_input_alternative:default_output_alternative']:
The given SavedModel SignatureDef contains the following input(s):
inputs['source_ids'] tensor_info:
dtype: DT_INT64
shape: (-1, -1)
name: model/att_seq2seq/hash_table_1_Lookup:0
inputs['source_len'] tensor_info:
dtype: DT_INT32
shape: (-1)
name: model/att_seq2seq/Minimum:0
inputs['source_tokens'] tensor_info:
dtype: DT_STRING
shape: (-1, -1)
name: model/att_seq2seq/strided_slice:0
The given SavedModel SignatureDef contains the following output(s):
outputs['attention_context'] tensor_info:
dtype: DT_FLOAT
shape: (-1, -1, 512)
name: model/att_seq2seq/transpose_4:0
outputs['attention_scores'] tensor_info:
dtype: DT_FLOAT
shape: unknown_rank
name: model/att_seq2seq/transpose_2:0
outputs['cell_output'] tensor_info:
dtype: DT_FLOAT
shape: (-1, -1, 256)
name: model/att_seq2seq/transpose_1:0
outputs['features.source_ids'] tensor_info:
dtype: DT_INT64
shape: (-1, -1)
name: model/att_seq2seq/hash_table_1_Lookup:0
outputs['features.source_len'] tensor_info:
dtype: DT_INT32
shape: (-1)
name: model/att_seq2seq/Minimum:0
outputs['features.source_tokens'] tensor_info:
dtype: DT_STRING
shape: (-1, -1)
name: model/att_seq2seq/strided_slice:0
outputs['logits'] tensor_info:
dtype: DT_FLOAT
shape: (-1, -1, 42)
name: model/att_seq2seq/transpose:0
outputs['predicted_ids'] tensor_info:
dtype: DT_INT32
shape: (-1, -1)
name: model/att_seq2seq/transpose_3:0
outputs['predicted_tokens'] tensor_info:
dtype: DT_STRING
shape: (-1, -1)
name: model/att_seq2seq/hash_table_3_Lookup:0
Method name is: tensorflow/serving/predict
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['source_ids'] tensor_info:
dtype: DT_INT64
shape: (-1, -1)
name: model/att_seq2seq/hash_table_1_Lookup:0
inputs['source_len'] tensor_info:
dtype: DT_INT32
shape: (-1)
name: model/att_seq2seq/Minimum:0
inputs['source_tokens'] tensor_info:
dtype: DT_STRING
shape: (-1, -1)
name: model/att_seq2seq/strided_slice:0
The given SavedModel SignatureDef contains the following output(s):
outputs['attention_context'] tensor_info:
dtype: DT_FLOAT
shape: (-1, -1, 512)
name: model/att_seq2seq/transpose_4:0
outputs['attention_scores'] tensor_info:
dtype: DT_FLOAT
shape: unknown_rank
name: model/att_seq2seq/transpose_2:0
outputs['cell_output'] tensor_info:
dtype: DT_FLOAT
shape: (-1, -1, 256)
name: model/att_seq2seq/transpose_1:0
outputs['features.source_ids'] tensor_info:
dtype: DT_INT64
shape: (-1, -1)
name: model/att_seq2seq/hash_table_1_Lookup:0
outputs['features.source_len'] tensor_info:
dtype: DT_INT32
shape: (-1)
name: model/att_seq2seq/Minimum:0
outputs['features.source_tokens'] tensor_info:
dtype: DT_STRING
shape: (-1, -1)
name: model/att_seq2seq/strided_slice:0
outputs['logits'] tensor_info:
dtype: DT_FLOAT
shape: (-1, -1, 42)
name: model/att_seq2seq/transpose:0
outputs['predicted_ids'] tensor_info:
dtype: DT_INT32
shape: (-1, -1)
name: model/att_seq2seq/transpose_3:0
outputs['predicted_tokens'] tensor_info:
dtype: DT_STRING
shape: (-1, -1)
name: model/att_seq2seq/hash_table_3_Lookup:0
Method name is: tensorflow/serving/predict
I served the model with TensorFlow Serving and sent the following request:
{
  "inputs": {
    "source_tokens": "[['DATE','OF','BIRT']]",
    "source_ids": [],
    "source_len": [3]
  }
}
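Note that the request above sends source_tokens as a single string literal, "[['DATE','OF','BIRT']]", rather than as a JSON array of strings matching the (1, None) string placeholder. A minimal sketch of building the same request body with a real nested array, assuming the TensorFlow Serving REST API's columnar "inputs" format; build_predict_body is a hypothetical helper, the signature name is taken from the saved_model_cli output, and source_ids is kept exactly as in the original request. This only fixes the encoding of the tokens; it does not by itself address the "both fed and fetched" error, since the signature still fetches the features.* outputs.

```python
import json


def build_predict_body(tokens, signature_name="serving_default"):
    """Hypothetical helper: build a columnar REST predict body for one sentence."""
    body = {
        "signature_name": signature_name,
        "inputs": {
            # A nested JSON list of strings with shape (1, N),
            # not the string literal "[['DATE','OF','BIRT']]".
            "source_tokens": [list(tokens)],
            # Kept exactly as in the original request.
            "source_ids": [],
            # One length per batch row, shape (1,).
            "source_len": [len(tokens)],
        },
    }
    return json.dumps(body)


body = build_predict_body(["DATE", "OF", "BIRT"])
print(body)
```

The resulting string would be POSTed to the model's predict endpoint (for example /v1/models/<model_name>:predict, where <model_name> is whatever name the model was served under).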
However, it returned:
{ "error": "model/att_seq2seq/Minimum:0 is both fed and fetched." }
Judging from the error message, the problem appears to be that the same tensor is being both fed and fetched.
Inspecting the SignatureDef shows that model/att_seq2seq/Minimum:0 is mapped to both inputs['source_len'] and outputs['features.source_len'].
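The same overlap can be checked mechanically for every input, not just source_len. A small pure-Python sketch, with the tensor names copied by hand from the serving_default signature above; fed_and_fetched is a hypothetical helper, not part of TensorFlow. Per the CLI output, source_ids and source_tokens also share their tensors with features.source_ids and features.source_tokens.

```python
# Alias -> tensor name, copied from the 'serving_default' SignatureDef above.
SIGNATURE_INPUTS = {
    "source_ids": "model/att_seq2seq/hash_table_1_Lookup:0",
    "source_len": "model/att_seq2seq/Minimum:0",
    "source_tokens": "model/att_seq2seq/strided_slice:0",
}
SIGNATURE_OUTPUTS = {
    "attention_context": "model/att_seq2seq/transpose_4:0",
    "attention_scores": "model/att_seq2seq/transpose_2:0",
    "cell_output": "model/att_seq2seq/transpose_1:0",
    "features.source_ids": "model/att_seq2seq/hash_table_1_Lookup:0",
    "features.source_len": "model/att_seq2seq/Minimum:0",
    "features.source_tokens": "model/att_seq2seq/strided_slice:0",
    "logits": "model/att_seq2seq/transpose:0",
    "predicted_ids": "model/att_seq2seq/transpose_3:0",
    "predicted_tokens": "model/att_seq2seq/hash_table_3_Lookup:0",
}


def fed_and_fetched(inputs, outputs):
    """Hypothetical helper: map each tensor that is both an input and an
    output to its (input_alias, output_alias) pair."""
    alias_by_tensor = {name: alias for alias, name in inputs.items()}
    return {
        name: (alias_by_tensor[name], alias)
        for alias, name in outputs.items()
        if name in alias_by_tensor
    }


clashes = fed_and_fetched(SIGNATURE_INPUTS, SIGNATURE_OUTPUTS)
for tensor, (in_alias, out_alias) in sorted(clashes.items()):
    print(f"{tensor}: inputs[{in_alias!r}] / outputs[{out_alias!r}]")
```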
How can I fix this? Is it possible to avoid fetching outputs['features.source_len']?
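One possible direction, assuming the gRPC API is used: TensorFlow Serving's PredictRequest message has an output_filter field that lists output aliases, and when it is set only those outputs are fetched. Whether the REST API exposes an equivalent is not checked here. A minimal sketch that computes which aliases could go into such a filter, i.e. every output whose tensor is not also an input; safe_output_aliases is hypothetical, and the tensor names are copied from the serving_default signature above.

```python
# Tensors that the 'serving_default' signature feeds (copied from saved_model_cli).
INPUT_TENSORS = {
    "model/att_seq2seq/hash_table_1_Lookup:0",  # source_ids
    "model/att_seq2seq/Minimum:0",              # source_len
    "model/att_seq2seq/strided_slice:0",        # source_tokens
}
# Output alias -> tensor name for the same signature.
OUTPUTS = {
    "attention_context": "model/att_seq2seq/transpose_4:0",
    "attention_scores": "model/att_seq2seq/transpose_2:0",
    "cell_output": "model/att_seq2seq/transpose_1:0",
    "features.source_ids": "model/att_seq2seq/hash_table_1_Lookup:0",
    "features.source_len": "model/att_seq2seq/Minimum:0",
    "features.source_tokens": "model/att_seq2seq/strided_slice:0",
    "logits": "model/att_seq2seq/transpose:0",
    "predicted_ids": "model/att_seq2seq/transpose_3:0",
    "predicted_tokens": "model/att_seq2seq/hash_table_3_Lookup:0",
}


def safe_output_aliases(outputs, input_tensors):
    """Hypothetical helper: output aliases whose tensors are not also fed."""
    return sorted(alias for alias, name in outputs.items() if name not in input_tensors)


output_filter = safe_output_aliases(OUTPUTS, INPUT_TENSORS)
print(output_filter)
# ['attention_context', 'attention_scores', 'cell_output', 'logits', 'predicted_ids', 'predicted_tokens']
```

With the tensorflow-serving-api package, this list could be copied into a request with request.output_filter.extend(output_filter); that step is not shown or tested here.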
How can I manually assign SignatureDefs with the Experiment API used in this repo?