我想知道如何在 cifar-10 中选择特定的类。例如,我想要 7,cifar-10 中的“马”类。我写了下面的代码。但是获得的数据不是我想要的,因为它的形状错误。
请赐教具体情况。
from keras.datasets import cifar10
(X_train, Y_train), (X_test, Y_test) = cifar10.load_data()
print('X_train shape: {0}, Y_train shape: {1}'.format(X_train.shape, Y_train.shape))
X_train shape: (50000, 32, 32, 3), Y_train shape: (50000, 1)
下面的代码是错误的。
import numpy as np
filter = np.where(Y_train == 7)
X_train = X_train[filter]
Y_train = Y_train[filter]
print('X_train shape: {0}, Y_train shape: {1}'.format(X_train.shape, Y_train.shape))
X_train shape: (5000, 32, 3), Y_train shape: (5000,)
预期输出低于
X_train shape: (5000, 32, 32, 3), Y_train shape: (5000,)