I'm currently trying to implement a network that is trained on image triplets. For this, I adapted a pair generator I found on the Internet:

```python
import numpy as np

def triplet_generator(triples, image_cache, datagens, batch_size=32):
    while True:
        # reshuffle once per full pass over `triples`
        indices = np.random.permutation(np.arange(len(triples)))
        num_batches = len(triples) // batch_size
        for bid in range(num_batches):
            batch_indices = indices[bid * batch_size : (bid + 1) * batch_size]
            batch = [triples[i] for i in batch_indices]
            X1 = np.zeros((batch_size, 64, 64, 3))
            X2 = np.zeros((batch_size, 64, 64, 3))
            X3 = np.zeros((batch_size, 64, 64, 3))
            for i, (image_filename_l, image_filename_m, image_filename_r) in enumerate(batch):
                if datagens is None or len(datagens) == 0:
                    X1[i] = image_cache[image_filename_l]
                    X2[i] = image_cache[image_filename_m]
                    X3[i] = image_cache[image_filename_r]
                else:
                    X1[i] = datagens[0].random_transform(image_cache[image_filename_l])
                    X2[i] = datagens[1].random_transform(image_cache[image_filename_m])
                    X3[i] = datagens[2].random_transform(image_cache[image_filename_r])
            yield [np.array(X1), np.array(X2), np.array(X3)]
```

I use it to train my network as follows:

```python
from keras.layers import Input, concatenate
from keras.models import Model
from keras.callbacks import ModelCheckpoint
from keras.optimizers import RMSprop

# create_base_network, input_shape, loss_desc_triplet, accuracy, BEST_MODEL_FILE,
# epochs, num_train_steps, num_val_steps and the generators are defined elsewhere in the script
base_network = create_base_network(input_shape)
base_network.summary()  # summary() prints itself and returns None
input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)
input_c = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the three branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)
processed_c = base_network(input_c)

merged_vector = concatenate([processed_a, processed_b, processed_c], axis=-1, name='merged_layer')

model = Model([input_a, input_b, input_c], merged_vector)
checkpoint = ModelCheckpoint(filepath=BEST_MODEL_FILE, save_best_only=True)

# train
rms = RMSprop()
model.compile(loss=loss_desc_triplet, optimizer=rms, metrics=[accuracy])

history = model.fit_generator(train_pair_gen,
                              steps_per_epoch=num_train_steps,
                              epochs=epochs,
                              validation_data=val_pair_gen,
                              validation_steps=num_val_steps,
                              callbacks=[checkpoint])
```

However, I get this error message:

```
Traceback (most recent call last):
  File "DeepLearningWithAugmentationWithTriplets.py", line 256, in <module>
    callbacks=[checkpoint])
  File "lib/python3.7/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "lib/python3.7/site-packages/keras/engine/training.py", line 1418, in fit_generator
    initial_epoch=initial_epoch)
  File "lib/python3.7/site-packages/keras/engine/training_generator.py", line 217, in fit_generator
    class_weight=class_weight)
  File "lib/python3.7/site-packages/keras/engine/training.py", line 1211, in train_on_batch
    class_weight=class_weight)
  File "lib/python3.7/site-packages/keras/engine/training.py", line 751, in _standardize_user_data
    exception_prefix='input')
  File "lib/python3.7/site-packages/keras/engine/training_utils.py", line 102, in standardize_input_data
    str(len(data)) + ' arrays: ' + str(data)[:200] + '...')
ValueError: Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 3 array(s), but instead got the following list of 1 arrays: [array([[[[0.36862746, 0.36862746, 0.36862746],
         [0.36862746, 0.36862746, 0.36862746],
         [0.36862746, 0.36862746, 0.36862746],
         ...,
         [0.41176471, 0.41176471, 0.41176471...
```

However, when I fetch the next item from the generator and print it, it is a list of three arrays:

```python
train_pair_gen = triplet_generator(triples_data_train, image_cache, datagens, batch_size)
[X1, X2, X3] = next(train_pair_gen)
print(X1.shape, X2.shape, X3.shape)  # -> (32, 64, 64, 3) (32, 64, 64, 3) (32, 64, 64, 3)
print("###")
print(len(next(train_pair_gen)))  # -> 3
```
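The two observations are consistent with how the standalone Keras 2.x `fit_generator` in the traceback reads generator output: it expects each yielded item to be a tuple `(inputs, targets)` or `(inputs, targets, sample_weights)`, not the list of model inputs alone. The sketch below imitates that unpacking in plain NumPy. `split_generator_output` is a hypothetical, simplified stand-in rather than Keras's actual code, and the `(batch_size, 1)` placeholder target is only illustrative.

```python
import numpy as np

def split_generator_output(generator_output):
    # Hypothetical helper: a simplified imitation of how Keras 2.x fit_generator
    # splits each item yielded by a generator (not Keras's actual code).
    if not hasattr(generator_output, '__len__'):
        raise ValueError('Generator output should be (x, y) or (x, y, sample_weight).')
    if len(generator_output) == 2:
        x, y = generator_output
        sample_weight = None
    elif len(generator_output) == 3:
        x, y, sample_weight = generator_output
    else:
        raise ValueError('Generator output should be (x, y) or (x, y, sample_weight).')
    return x, y, sample_weight

batch_size = 32
X1 = np.zeros((batch_size, 64, 64, 3))
X2 = np.zeros((batch_size, 64, 64, 3))
X3 = np.zeros((batch_size, 64, 64, 3))

# A plain list of three arrays is read as (x, y, sample_weight):
# only X1 ends up as the model input, X2 as targets, X3 as sample weights.
x, y, sw = split_generator_output([X1, X2, X3])
print(type(x).__name__, x.shape)  # ndarray (32, 64, 64, 3)

# An (inputs, targets) tuple keeps all three arrays together as the inputs.
dummy_y = np.zeros((batch_size, 1))  # illustrative placeholder target
x, y, sw = split_generator_output(([X1, X2, X3], dummy_y))
print(len(x), y.shape, sw)  # 3 (32, 1) None
```

A single array passed as `x` is what Keras then reports as "a list of 1 arrays" when the model declares three inputs.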

What am I doing wrong?

Base network definition:

```
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_131 (InputLayer)       (None, 64, 64, 3)         0
_________________________________________________________________
conv2d_100 (Conv2D)          (None, 29, 29, 32)        4736
_________________________________________________________________
conv2d_101 (Conv2D)          (None, 8, 8, 64)          73792
_________________________________________________________________
conv2d_102 (Conv2D)          (None, 1, 1, 128)         204928
_________________________________________________________________
flatten_34 (Flatten)         (None, 128)               0
=================================================================
Total params: 283,456
Trainable params: 283,456
Non-trainable params: 0
```

