3

在 google AutoML 中创建模型后,我们可以使用提供的 python 代码进行预测。这是代码:

import sys

from google.cloud import automl_v1beta1
from google.cloud.automl_v1beta1.proto import service_pb2


def get_prediction(content, project_id, model_id):
  prediction_client = automl_v1beta1.PredictionServiceClient()

  name = 'projects/{}/locations/us-central1/models/{}'.format(project_id, model_id)
  payload = {'image': {'image_bytes': content }}
  params = {}
  request = prediction_client.predict(name, payload, params)
  return request  # waits till request is returned

if __name__ == '__main__':
  file_path = sys.argv[1]
  project_id = sys.argv[2]
  model_id = sys.argv[3]

  with open(file_path, 'rb') as ff:
    content = ff.read()

  print get_prediction(content, project_id,  model_id)

我意识到它只会打印得分高于 threshold 的检测结果value = 0.5。示例输出:

payload {
  classification {
    score: 0.562688529491
  }
  display_name: "dog"
}

如何打印分数低于阈值 0.5 的其他检测结果(例如将阈值更改为 0.3)?

4

1 回答 1

4

请参阅此处的 api 文档

参数

具有字符串属性的对象

附加域特定参数,任何字符串的长度不得超过 25000 个字符。

对于图像分类:

score_threshold - (float) 一个从 0.0 到 1.0 的值。当模型对图像进行预测时,它只会产生至少具有此置信度分数阈值的结果。默认值为 0.5。

proto中字段的实际描述是

map<string,string> params;

因此,您将更改已设置为空字典的 params 变量。params将变量更改为 :params = {"score_threshold": "0.3"}将起作用。

于 2019-04-12T06:26:57.243 回答