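I have successfully created a model and want to export it so that I can run predictions from a Java client. However, calling predict from Java through the prediction stub fails, because I need to put a serialized Example into the placeholder when calling predict: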
You must feed a value for placeholder tensor 'input_example_tensor' with dtype string and shape [?]
Can anyone help me build this placeholder tensor in Java using protobuf?
The full error is:
```
io.grpc.StatusRuntimeException: INVALID_ARGUMENT: You must feed a value for placeholder tensor 'input_example_tensor' with dtype string and shape [?]
     [[Node: input_example_tensor = Placeholder[dtype=DT_STRING, shape=[?], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
    at io.grpc.stub.ClientCalls.toStatusRuntimeException(ClientCalls.java:221)
    at io.grpc.stub.ClientCalls.getUnchecked(ClientCalls.java:202)
    at io.grpc.stub.ClientCalls.blockingUnaryCall(ClientCalls.java:131)
    at tensorflow.serving.PredictionServiceGrpc$PredictionServiceBlockingStub.predict(PredictionServiceGrpc.java:332)
```
This is the signature definition I am using, as reported by saved_model_cli:
```
The given SavedModel SignatureDef contains the following input(s):
  inputs['inputs'] tensor_info:
      dtype: DT_STRING
      shape: (-1)
      name: Placeholder:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['classes'] tensor_info:
      dtype: DT_STRING
      shape: (-1, 2)
      name: dnn/head/Tile:0
  outputs['scores'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 2)
      name: dnn/head/predictions/probabilities:0
Method name is: tensorflow/serving/classify
```
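The signature above declares `classes` and `scores` with shape (-1, 2): one row per input Example and one column per class. A `TensorProto` carries such a 2-D output as a single flat, row-major value list (`float_val` for `scores`, `string_val` for `classes`), so reading per-example results means splitting that list into rows. Below is a minimal plain-Java sketch of that step; `RowMajor` and `reshape` are hypothetical helper names, and it assumes the server fills the typed value lists rather than the packed `tensor_content` field.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper (not part of TensorFlow): split a flat row-major tensor list into rows. */
public class RowMajor {

    /**
     * Splits {@code flat} (length rows * cols, row-major) into {@code rows} lists of {@code cols}.
     * Works the same for List<Float> (float_val) and List<ByteString> (string_val).
     */
    static <T> List<List<T>> reshape(List<T> flat, int rows, int cols) {
        if (rows < 0 || cols < 0 || (long) rows * cols != flat.size()) {
            throw new IllegalArgumentException(
                    "Expected " + rows + " x " + cols + " values but got " + flat.size());
        }
        List<List<T>> out = new ArrayList<>(rows);
        for (int r = 0; r < rows; r++) {
            // subList is a view; copy it so each row is independent of the input list
            out.add(new ArrayList<>(flat.subList(r * cols, (r + 1) * cols)));
        }
        return out;
    }

    public static void main(String[] args) {
        // Made-up 'scores' values for two examples and two classes
        List<Float> scores = Arrays.asList(0.9f, 0.1f, 0.3f, 0.7f);
        System.out.println(reshape(scores, 2, 2));
        // prints [[0.9, 0.1], [0.3, 0.7]]
    }
}
```

With a real `PredictResponse`, the arguments could come from `TensorProto scores = response.getOutputsOrThrow("scores")`, passing `scores.getFloatValList()` as the flat list and `(int) scores.getTensorShape().getDim(0).getSize()` and `(int) scores.getTensorShape().getDim(1).getSize()` as the row and column counts.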
Below is the Java code I use to build the request object:
```java
import com.google.protobuf.ByteString;
import com.google.protobuf.Int64Value;
import io.grpc.StatusRuntimeException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.tensorflow.example.BytesList;
import org.tensorflow.example.Example;
import org.tensorflow.example.Feature;
import org.tensorflow.example.Features;
import org.tensorflow.example.FloatList;
import org.tensorflow.example.Int64List;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;

// Inside the client method; modelName, modelVersion and blockingStub are defined elsewhere.
long start1 = System.currentTimeMillis();

// Raw feature values for a single example.
Map<String, Object> inputData = new HashMap<>();
inputData.put("bid", -1.628f);
inputData.put("day_of_week", "6");
inputData.put("hour_of_day", "5");
inputData.put("connType", "wifi");
inputData.put("geo", "0");
inputData.put("size", "Phone");
inputData.put("cat", "arcadegame");
inputData.put("os", "7");
inputData.put("conv", 4f);
inputData.put("time", 650907f);
inputData.put("conn", 5f);

// Wrap each value in a tf.Example Feature of the matching list type.
Map<String, Feature> inputFeatureMap = new HashMap<>();
for (Map.Entry<String, Object> entry : inputData.entrySet()) {
    Object featureValue = entry.getValue();
    Feature feature = null;
    if (featureValue instanceof Float) {
        feature = Feature.newBuilder()
                .setFloatList(FloatList.newBuilder().addValue((Float) featureValue))
                .build();
    } else if (featureValue instanceof String) {
        feature = Feature.newBuilder()
                .setBytesList(BytesList.newBuilder()
                        .addValue(ByteString.copyFromUtf8((String) featureValue)))
                .build();
    } else if (featureValue instanceof Integer) {
        feature = Feature.newBuilder()
                .setInt64List(Int64List.newBuilder().addValue((Integer) featureValue))
                .build();
    }
    if (feature != null) {
        inputFeatureMap.put(entry.getKey(), feature);
    }
}

// Build and serialize the Example once, after all features have been collected.
Features features = Features.newBuilder().putAllFeature(inputFeatureMap).build();
List<ByteString> inputList = new ArrayList<>();
inputList.add(Example.newBuilder().setFeatures(features).build().toByteString());

// 1-D DT_STRING tensor; its single dimension equals the number of serialized Examples.
TensorShapeProto shape = TensorShapeProto.newBuilder()
        .addDim(TensorShapeProto.Dim.newBuilder().setSize(inputList.size()))
        .build();
TensorProto proto = TensorProto.newBuilder()
        .setDtype(DataType.DT_STRING)
        .setTensorShape(shape)
        .addAllStringVal(inputList)
        .build();

// Generate the gRPC request.
Int64Value version = Int64Value.newBuilder().setValue(modelVersion).build();
Model.ModelSpec modelSpec = Model.ModelSpec.newBuilder()
        .setName(modelName)
        .setVersion(version)
        .build();
Predict.PredictRequest request = Predict.PredictRequest.newBuilder()
        .setModelSpec(modelSpec)
        .putInputs("inputs", proto)
        .build();

// Call the gRPC server.
Predict.PredictResponse response;
try {
    response = blockingStub.predict(request);
    long diff = System.currentTimeMillis() - start1;
    System.out.println("diff:" + diff);
    System.out.println("Response output count is - " + response.getOutputsCount());
    System.out.println("outputs are: - " + response.getOutputsMap());
    System.out.println("*********************************************");
    System.out.println("PREDICTION COMPLETE>>>>>>");
} catch (StatusRuntimeException e) {
    e.printStackTrace();
    return;
}
```
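In the feature loop above, a value that is not exactly a `Float`, `String` or `Integer` (a `Double` or `Long`, for example) leaves `feature` as `null`, so that column is silently dropped from the Example. Below is a minimal plain-Java sketch of a stricter dispatch step. `FeatureKinds`, `FeatureKind`, `kindOf` and `classify` are hypothetical names; the three kinds mirror the `float_list`, `int64_list` and `bytes_list` options of the `Feature` proto, and treating a `Double` as a float feature (narrowed to 32 bits once added to a `FloatList`) is an assumption of this sketch.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical helper (not part of TensorFlow): pick the tf.Example list type for each value. */
public class FeatureKinds {

    /** The three value lists a tf.Example Feature can hold. */
    enum FeatureKind { FLOAT_LIST, INT64_LIST, BYTES_LIST }

    /** Maps a Java value to its Feature list kind; throws instead of silently skipping it. */
    static FeatureKind kindOf(Object value) {
        if (value instanceof Float || value instanceof Double) {
            return FeatureKind.FLOAT_LIST;   // FloatList stores 32-bit floats
        } else if (value instanceof Integer || value instanceof Long) {
            return FeatureKind.INT64_LIST;   // Int64List stores longs
        } else if (value instanceof String) {
            return FeatureKind.BYTES_LIST;   // strings are stored as UTF-8 bytes
        }
        throw new IllegalArgumentException("Unsupported feature value type: "
                + (value == null ? "null" : value.getClass().getName()));
    }

    /** Classifies every entry of an input map, keeping insertion order. */
    static Map<String, FeatureKind> classify(Map<String, Object> inputs) {
        Map<String, FeatureKind> kinds = new LinkedHashMap<>();
        for (Map.Entry<String, Object> e : inputs.entrySet()) {
            kinds.put(e.getKey(), kindOf(e.getValue()));
        }
        return kinds;
    }

    public static void main(String[] args) {
        Map<String, Object> inputData = new LinkedHashMap<>();
        inputData.put("bid", -1.628f);
        inputData.put("day_of_week", "6");
        inputData.put("conn", 5L);
        System.out.println(classify(inputData));
        // prints {bid=FLOAT_LIST, day_of_week=BYTES_LIST, conn=INT64_LIST}
    }
}
```

Inside the original loop, a `switch (kindOf(featureValue))` could then choose between the `setFloatList`, `setInt64List` and `setBytesList` builders, so an unsupported type fails with an exception rather than disappearing from the request.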
Note: I exported the model successfully using the following export function:
```python
def _make_serving_input_fn(working_dir):
    """Creates an input function reading from raw data.

    Args:
      working_dir: Directory to read transformed metadata from.

    Returns:
      The serving input function.
    """
    raw_feature_spec = RAW_DATA_METADATA.schema.as_feature_spec()
    # Remove label since it is not available during serving.
    raw_feature_spec.pop(LABEL_KEY)

    def serving_input_fn():
        raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
            raw_feature_spec)
        raw_features, _, default_inputs = raw_input_fn()

        # Apply the transform function that was used to generate the materialized
        # data.
        _, transformed_features = (
            saved_transform_io.partially_apply_saved_transform(
                os.path.join(working_dir, transform_fn_io.TRANSFORM_FN_DIR),
                raw_features))

        serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[None])
        receiver_tensors = {'examples': serialized_tf_example}
        return tf.estimator.export.ServingInputReceiver(transformed_features, receiver_tensors)

    return serving_input_fn
```