0

例如,我使用自己的数据集微调了一个 VGG 网络,只有 2 个标签foobar. 我已通过此链接使用示例将图像转换为 tf.record :

labels_to_class_names = dict(zip(range(len(class_names)), class_names))
dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

我将基于这个新模型构建一个 API 来预测图像,我的问题是:是否有任何正式的方法可以从检查点文件或数据集中获取标签字符串(例如predict_image("abc.png")返回foo字符串)?因为我不知道 logits 层中哪个节点代表 label foo,哪个节点代表bar

我试过四处寻找,但没有帮助,我仍然是一个 tensorflow 新手。

4

1 回答 1

0

模型(以及检查点文件)没有每个类的名称。它只有一定数量的输出神经元,第一个对应第一类,第二个对应第二类,以此类推。

如果您想知道哪个是哪个,请查看此行创建的标签文件(很可能命名为 labels.txt):

dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

或者,您可以检查labels_to_class_names字典的内容:

In [1]: class_names=['aaa', 'bbb', 'ccc']

In [2]: labels_to_class_names = dict(zip(range(len(class_names)), class_names))

In [3]: labels_to_class_names
Out[3]: {0: 'aaa', 1: 'bbb', 2: 'ccc'}

-> 模型输出中索引 0 处的值 = 类 'aaa' 等

于 2018-06-28T16:06:48.163 回答