1

我加载mnist数据集如下,

(X_train, y_train), (X_test, y_test) = mnist.load_data()

然而,由于我需要加载和训练我自己的数据集,我编写了如下的小脚本,它将给出准确的训练和测试值

def load_train(path):
X_train = []
y_train = []
print('Read train images')
for j in range(10):
    files = glob(path + "*.jpeg")
    for fl in files:
        img = get_im(fl)
        print(fl)
        X_train.append(img)
        y_train.append(j)

return np.asarray(X_train), np.asarray(y_train)

相关模型在训练时生成一个大小为 (64, 28, 28, 1) 的 numpy 数组。我从生成的图像中连接 image_batch,如下所示,

    X = np.concatenate((image_batch, generated_images))

但是我收到以下错误,

ValueError:所有输入数组必须具有相同的维数

img_batch 的大小为 (64, 28, 28) generated_images 的大小为 (64, 28, 28, 1)

如何扩展img_batchX_train 中的维度以便与 generate_images 连接?或者有没有其他方法可以加载自定义图像来代替 loadmnist?

4

2 回答 2

2

在 python 中有一个函数被调用np.expand_dims(),它可以沿参数中提供的轴扩展任何数组的维度。在您的情况下,使用img_batch = np.expand_dims(img_batch, axis=3).

另一种方法是使用reshape@Ioannis Nasios 建议的函数。img_batch = img_batch.reshape(64,28,28,1)

于 2018-08-31T10:58:35.410 回答
0
image_batch = image_batch.reshape(64, 28, 28, 1)
于 2018-08-31T10:56:32.930 回答