0

我已经在自定义数据集上为单个对象检测任务训练了一个移动 SSD V2 模型,并将其转换为 tflite。当我使用解释器加载 .tflite 模型进行测试并使用 : 获取输入详细信息input_details = model.get_input_details()时,它会输出

[{'name': 'normalized_input_image_tensor',
'index': 272,
'shape': array([  1, 300, 300,   3], dtype=int32),
'dtype': numpy.uint8,
'quantization': (0.0078125, 128)}]

我知道“300x300”是图像的高度和宽度,“3”用于 RGB 通道,但第一个元素(“1”)指的是什么?

4

1 回答 1

0

形状:[Batch_size, height, width, channel]

如果要更改该大小,则需要在转换为 pb 文件之前对其进行设置,例如使用export_inference_graph

于 2020-01-02T07:00:49.353 回答