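I get the error above when using TensorFlow Serving with Python 3.7. The TensorFlow model trains correctly, but when I try to add the model's information to the TensorFlow model server's config, the request fails with that error.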
The code that saves the model:
import os

import tensorflow as tf

# Export under <checkpoint_prefix>/1; TensorFlow Serving reads the "1" directory as model version 1
checkpoint_versioned = os.path.join(checkpoint_prefix, "1")
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(checkpoint_versioned)
# Wrap the graph's input and output tensors so they can be used in a signature
tensor_info_input_1 = tf.compat.v1.saved_model.utils.build_tensor_info(self.sess.graph.get_operation_by_name("input_x").outputs[0])
tensor_info_input_2 = tf.compat.v1.saved_model.utils.build_tensor_info(self.sess.graph.get_operation_by_name("dropout_keep_prob").outputs[0])
tensor_info_output = tf.compat.v1.saved_model.utils.build_tensor_info(self.sess.graph.get_operation_by_name("output/proba").outputs[0])
prediction_signature = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
    inputs={
        'input_x': tensor_info_input_1,
        'dropout_keep_prob': tensor_info_input_2,
    },
    outputs={
        'proba': tensor_info_output,
    },
    method_name=tf.compat.v1.saved_model.signature_constants.PREDICT_METHOD_NAME)
# Export the session's graph and variables under the "serve" tag with a signature named 'predict'
builder.add_meta_graph_and_variables(
    sess=self.sess,
    tags=[tf.compat.v1.saved_model.tag_constants.SERVING],
    signature_def_map={'predict': prediction_signature},
    saver=self.saver)
path = builder.save()
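One thing worth ruling out: TensorFlow Serving expects `base_path` to be the parent directory that contains integer-named version directories, each holding a `saved_model.pb`. Because the export code writes to `os.path.join(checkpoint_prefix, "1")`, the `base_path` sent in the config should presumably be `checkpoint_prefix` itself, not the `.../1` directory. The stdlib-only sketch below lists the versions a given `base_path` would expose; `servable_versions` is a hypothetical helper name, and it only checks the directory layout, not whether the SavedModel itself loads.

```python
import re
from pathlib import Path
from typing import List

# Version directories must be named with ASCII digits only, e.g. "1", "2", "10".
_VERSION_DIR = re.compile(r"[0-9]+")


def servable_versions(base_path: str) -> List[int]:
    """Return the sorted version numbers under base_path that contain a saved_model.pb.

    Hypothetical helper: it mirrors the <base_path>/<version>/saved_model.pb
    layout that TensorFlow Serving looks for, and does not validate the model.
    """
    root = Path(base_path)
    if not root.is_dir():
        # A missing or non-directory base_path exposes no versions at all.
        return []
    versions = []
    for child in root.iterdir():
        # Only integer-named subdirectories are treated as model versions.
        if child.is_dir() and _VERSION_DIR.fullmatch(child.name):
            if (child / "saved_model.pb").is_file():
                versions.append(int(child.name))
    return sorted(versions)
```

For example, after `builder.save()` succeeds, `servable_versions(checkpoint_prefix)` should return a list containing `1`; an empty list for the path being used as `base_path` suggests it points at the wrong directory level.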
The code that adds the trained model to the model server's config (via a gRPC reload request):
import grpc
from google.protobuf import text_format
from tensorflow_serving.apis import model_management_pb2, model_service_pb2_grpc
from tensorflow_serving.config import model_server_config_pb2

# Read the current server config (protobuf text format) from disk
config_ini = self.read_config_file()
# gRPC connection to the model server's ModelService
channel = grpc.insecure_channel(host)
stub = model_service_pb2_grpc.ModelServiceStub(channel)
request = model_management_pb2.ReloadConfigRequest()
model_server_config = model_server_config_pb2.ModelServerConfig()
model_server_config = text_format.Parse(text=config_ini, message=model_server_config)
model_server_config = self.delete_from_config(model_server_config, name)
config_list = model_server_config_pb2.ModelConfigList()
# Create a config to add to the list of served models
one_config = config_list.config.add()
one_config.name = name
one_config.base_path = base_path
one_config.model_platform = model_platform
model_server_config.model_config_list.MergeFrom(config_list)
# Send the updated config to the server with a 10-second timeout
request.config.CopyFrom(model_server_config)
response = stub.HandleReloadConfigRequest(request, timeout=10)
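For comparison with `/home/swapnil/models.conf`, the sketch below renders the same `model_config_list` structure that the reload request builds (one `config` entry per model, with `name`, `base_path`, and `model_platform`) in protobuf text format. It is a hand-written illustration that handles only these three string fields, and `render_model_config_list` is a hypothetical name; in the real code, calling `text_format.MessageToString(model_server_config)` from `google.protobuf` before sending prints exactly what the server receives and is the more reliable check.

```python
from typing import Iterable, Tuple


def _quote(value: str) -> str:
    # Escape backslashes first, then double quotes and newlines, for a text-format string literal.
    escaped = value.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")
    return '"' + escaped + '"'


def render_model_config_list(models: Iterable[Tuple[str, str, str]]) -> str:
    """Render (name, base_path, model_platform) triples as a text-format
    ModelServerConfig containing a model_config_list (hypothetical helper)."""
    lines = ["model_config_list {"]
    for name, base_path, platform in models:
        lines += [
            "  config {",
            f"    name: {_quote(name)}",
            f"    base_path: {_quote(base_path)}",
            f"    model_platform: {_quote(platform)}",
            "  }",
        ]
    lines.append("}")
    return "\n".join(lines) + "\n"


if __name__ == "__main__":
    # Made-up model name and path, for illustration only.
    print(render_model_config_list([("my_model", "/models/my_model", "tensorflow")]))
```

`"tensorflow"` is the `model_platform` value TensorFlow Serving uses for SavedModels; the other values in the example call are placeholders.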
The package versions I am using are:
python = 3.7.5
tensorflow = '2.0.0-rc0'
tensorflow_model_server = TensorFlow ModelServer: 2.1.0-rc1+dev.sha.d83512c, TensorFlow Library: 2.1.0
The command used to run the TensorFlow model server:
tensorflow_model_server --port=9000 --model_config_file="/home/swapnil/models.conf"
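This command exposes the server's gRPC API on port 9000, so the `host` passed to `grpc.insecure_channel` must include that port (for example `"localhost:9000"` if the client runs on the same machine; that hostname is an assumption). A quick stdlib-only check that something is accepting TCP connections there can help separate connection problems from config problems. `grpc_port_reachable` is a hypothetical helper; a successful connect only shows that a socket is listening, not that it is the model server or that the reload call will succeed.

```python
import socket


def grpc_port_reachable(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True if a plain TCP connection to host:port succeeds within timeout.

    Hypothetical helper: it does not speak gRPC, it only confirms a listener exists.
    """
    try:
        # create_connection resolves the host and tries each address in turn.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, unreachable hosts, DNS failures, and timeouts.
        return False


if __name__ == "__main__":
    print(grpc_port_reachable("localhost", 9000))
```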