例如,我使用自己的数据集微调了一个 VGG 网络,只有 2 个标签foo
和bar
. 我已通过此链接使用示例将图像转换为 tf.record :
labels_to_class_names = dict(zip(range(len(class_names)), class_names))
dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
我将基于这个新模型构建一个 API 来预测图像,我的问题是:是否有任何正式的方法可以从检查点文件或数据集中获取标签字符串(例如predict_image("abc.png")
返回foo
字符串)?因为我不知道 logits 层中哪个节点代表 label foo
,哪个节点代表bar
我试过四处寻找,但没有帮助,我仍然是一个 tensorflow 新手。