11

我使用来自 Keras 的预训练 VGG-16 模型。

到目前为止,我的工作源代码是这样的:

from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.vgg16 import preprocess_input
from keras.applications.vgg16 import decode_predictions

model = VGG16()

print(model.summary())

image = load_img('./pictures/door.jpg', target_size=(224, 224))
image = img_to_array(image)  #output Numpy-array

image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))

image = preprocess_input(image)
yhat = model.predict(image)

label = decode_predictions(yhat)
label = label[0][0]

print('%s (%.2f%%)' % (label[1], label[2]*100))

我发现该模型接受了 1000 个课程的训练。是否有可能获得该模型所训练的类的列表?打印出所有预测标签不是一种选择,因为只有 5 个返回。

提前致谢

4

4 回答 4

9

您可以使用 decode_predictions 并在参数中传递类的总数top=1000(只有其默认值为 5)。

或者您可以查看 Keras 如何在内部执行此操作:它下载文件imagenet_class_index.json(并且通常将其缓存在~/.keras/models/. 这是一个包含所有类标签的简单 json 文件。

于 2017-11-24T14:14:17.160 回答
1

我想如果你做这样的事情:

vgg16 = keras.applications.vgg16.VGG16(include_top=True,
                               weights='imagenet',
                               input_tensor=None,
                               input_shape=None,
                               pooling=None,
                               classes=1000)

vgg16.decode_predictions(np.arange(1000), top=1000)

用你的预测数组替换 np.arange(1000)。到目前为止未经测试的代码。

我认为在这里链接到培训标签:http: //image-net.org/challenges/LSVRC/2014/browse-synsets

于 2018-01-15T02:15:50.527 回答
0

如果您稍微编辑代码,您可以获得您提供的示例的所有热门预测的列表。Tensorflowdecode_predictions返回列表类预测元组的列表。因此,首先,将 top=1000 参数添加为 @YSelf 推荐,label = decode_predictions(yhat, top=1000)然后更改label = label[0][0]label = label[0][:]选择所有预测。标签看起来像这样:

[('n04252225', 'snowplow', 0.4144803),
('n03796401', 'moving_van', 0.09205707),
('n04461696', 'tow_truck', 0.08912289),
('n03930630', 'pickup', 0.07173037),
('n04467665', 'trailer_truck', 0.048759833),
('n02930766', 'cab', 0.043586567),
('n04037443', 'racer', 0.036957625),....)]

从这里您需要进行元组解包。如果您只想获得 1000 个课程的列表,您可以调用[y for (x,y,z) in label],您将获得所有 1000 个课程的列表。输出如下所示:

['snowplow',
'moving_van',
'tow_truck',
'pickup',
'trailer_truck',
'cab',
'racer',....]
于 2018-07-30T23:57:33.590 回答
0

这一行将打印出所有类名和索引: decode_predictions(np.expand_dims(np.arange(1000), 0), top=1000)

于 2020-11-21T19:50:57.163 回答