0

我想知道如何在 cifar-10 中选择特定的类。例如,我想要 7,cifar-10 中的“马”类。我写了下面的代码。但是获得的数据不是我想要的,因为它的形状错误。

请赐教具体情况。

from keras.datasets import cifar10

(X_train, Y_train), (X_test, Y_test) = cifar10.load_data()

print('X_train shape: {0}, Y_train shape: {1}'.format(X_train.shape, Y_train.shape))
X_train shape: (50000, 32, 32, 3), Y_train shape: (50000, 1)

下面的代码是错误的。

import numpy as np

filter = np.where(Y_train == 7)

X_train = X_train[filter]
Y_train = Y_train[filter]

print('X_train shape: {0}, Y_train shape: {1}'.format(X_train.shape, Y_train.shape))
X_train shape: (5000, 32, 3), Y_train shape: (5000,)

预期输出低于

X_train shape: (5000, 32, 32, 3), Y_train shape: (5000,)
4

1 回答 1

0

对于切片,请执行以下操作:

X_train = X_train[filter[0], ...]
Y_train = Y_train[filter[0], ...]

形状将是

X_train shape: (5000, 32, 32, 3), Y_train shape: (5000, 1)
于 2021-05-20T05:54:55.270 回答