I am currently trying to implement a network that trains on image triplets. To do this, I adapted a pair generator I found on the Internet:
import numpy as np

def triplet_generator(triples, image_cache, datagens, batch_size=32):
    while True:
        # shuffle once per pass over the data (not once per batch)
        indices = np.random.permutation(np.arange(len(triples)))
        num_batches = len(triples) // batch_size
        for bid in range(num_batches):
            batch_indices = indices[bid * batch_size : (bid + 1) * batch_size]
            batch = [triples[i] for i in batch_indices]
            X1 = np.zeros((batch_size, 64, 64, 3))
            X2 = np.zeros((batch_size, 64, 64, 3))
            X3 = np.zeros((batch_size, 64, 64, 3))
            for i, (image_filename_l, image_filename_m, image_filename_r) in enumerate(batch):
                if datagens is None or len(datagens) == 0:
                    X1[i] = image_cache[image_filename_l]
                    X2[i] = image_cache[image_filename_m]
                    X3[i] = image_cache[image_filename_r]
                else:
                    X1[i] = datagens[0].random_transform(image_cache[image_filename_l])
                    X2[i] = datagens[1].random_transform(image_cache[image_filename_m])
                    X3[i] = datagens[2].random_transform(image_cache[image_filename_r])
            yield [np.array(X1), np.array(X2), np.array(X3)]
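The index logic in the generator draws a fresh random order on every pass over `triples` and, because `num_batches` uses integer division, skips the last `len(triples) % batch_size` triples of each pass. A minimal NumPy sketch of just that batching step (the helper `batch_index_plan` is hypothetical, and it uses `np.random.default_rng` instead of the legacy `np.random.permutation`):

```python
import numpy as np

def batch_index_plan(n_items, batch_size, rng=None):
    """Return the index batches that one pass of the generator would use (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    # one permutation per pass over the data, as in triplet_generator
    indices = rng.permutation(n_items)
    # integer division: the incomplete final batch is dropped
    num_batches = n_items // batch_size
    return [indices[b * batch_size:(b + 1) * batch_size] for b in range(num_batches)]

batches = batch_index_plan(70, 32, np.random.default_rng(0))
print(len(batches), [len(b) for b in batches])  # 2 [32, 32]
```

With 70 triples and a batch size of 32, only 64 triples are used per pass; the other 6 are simply not drawn that time.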
I use it to train my network:
from keras.callbacks import ModelCheckpoint
from keras.layers import Input, concatenate
from keras.models import Model
from keras.optimizers import RMSprop

base_network = create_base_network(input_shape)
base_network.summary()  # summary() prints the table itself and returns None
input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)
input_c = Input(shape=input_shape)
# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the three branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)
processed_c = base_network(input_c)
merged_vector = concatenate([processed_a, processed_b, processed_c], axis=-1, name='merged_layer')
model = Model([input_a, input_b, input_c], merged_vector)
checkpoint = ModelCheckpoint(filepath=BEST_MODEL_FILE, save_best_only=True)
# train
rms = RMSprop()
model.compile(loss=loss_desc_triplet, optimizer=rms, metrics=[accuracy])
history = model.fit_generator(train_pair_gen,
                              steps_per_epoch=num_train_steps,
                              epochs=epochs,
                              validation_data=val_pair_gen,
                              validation_steps=num_val_steps,
                              callbacks=[checkpoint])
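`loss_desc_triplet` is not shown in the post. For orientation only, here is a NumPy sketch of a common triplet margin loss computed from the merged vector. It assumes the three 128-dimensional embeddings are concatenated in anchor, positive, negative order (`input_a`, `input_b`, `input_c`) and uses a margin of 1.0; both choices, and the name `triplet_margin_loss`, are hypothetical. An actual Keras loss would receive `(y_true, y_pred)` and use backend ops rather than NumPy.

```python
import numpy as np

EMB_DIM = 128  # matches the (None, 128) output of the base network summary

def triplet_margin_loss(merged, margin=1.0):
    """Hypothetical triplet loss on a (batch, 3 * EMB_DIM) merged vector.

    Assumes the layout [anchor | positive | negative], as produced by
    concatenate([processed_a, processed_b, processed_c], axis=-1).
    """
    a = merged[:, :EMB_DIM]
    p = merged[:, EMB_DIM:2 * EMB_DIM]
    n = merged[:, 2 * EMB_DIM:]
    d_pos = np.sum((a - p) ** 2, axis=1)  # squared anchor-positive distance
    d_neg = np.sum((a - n) ** 2, axis=1)  # squared anchor-negative distance
    # hinge: penalise triplets where the negative is not at least `margin` farther away
    return np.mean(np.maximum(d_pos - d_neg + margin, 0.0))

print(triplet_margin_loss(np.zeros((4, 3 * EMB_DIM))))  # 1.0
```

When all three embeddings coincide, both distances are zero and the loss equals the margin; it drops to zero once the negative is farther from the anchor than the positive by at least the margin.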
However, I get this error message:
Traceback (most recent call last):
  File "DeepLearningWithAugmentationWithTriplets.py", line 256
    callbacks=[checkpoint])
  File "lib/python3.7/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "lib/python3.7/site-packages/keras/engine/training.py", line 1418, in fit_generator
    initial_epoch=initial_epoch)
  File "lib/python3.7/site-packages/keras/engine/training_generator.py", line 217, in fit_generator
    class_weight=class_weight)
  File "lib/python3.7/site-packages/keras/engine/training.py", line 1211, in train_on_batch
    class_weight=class_weight)
  File "lib/python3.7/site-packages/keras/engine/training.py", line 751, in _standardize_user_data
    exception_prefix='input')
  File "lib/python3.7/site-packages/keras/engine/training_utils.py", line 102, in standardize_input_data
    str(len(data)) + ' arrays: ' + str(data)[:200] + '...')
ValueError: Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 3 array(s), but instead got the following list of 1 arrays: [array([[[[0.36862746, 0.36862746, 0.36862746],
         [0.36862746, 0.36862746, 0.36862746],
         [0.36862746, 0.36862746, ...
         ...,
         [0.41176471, 0.41176471, 0.41176471...
Yet when I print the next item from the generator, it is a list of three arrays:
train_pair_gen = triplet_generator(triples_data_train, image_cache, datagens, batch_size)
[X1, X2, X3] = next(train_pair_gen)
print(X1.shape, X2.shape, X3.shape)  # -> (32, 64, 64, 3) (32, 64, 64, 3) (32, 64, 64, 3)
print("###")
print(len(next(train_pair_gen)))  # -> 3
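The Keras 2 documentation for `fit_generator` says each generator item must be a tuple `(inputs, targets)` or `(inputs, targets, sample_weights)`. The sketch below mimics that convention in plain Python; `split_generator_output` is a hypothetical stand-in, not Keras source. It shows how a three-element item like the one printed above would be split:

```python
import numpy as np

def split_generator_output(item):
    """Mimic the (inputs, targets[, sample_weights]) convention of fit_generator (hypothetical)."""
    if len(item) == 2:
        x, y = item
        w = None
    elif len(item) == 3:
        x, y, w = item
    else:
        raise ValueError("expected a tuple of length 2 or 3")
    return x, y, w

X1 = np.full((32, 64, 64, 3), 1.0)
X2 = np.full((32, 64, 64, 3), 2.0)
X3 = np.full((32, 64, 64, 3), 3.0)
x, y, w = split_generator_output([X1, X2, X3])
print(x is X1, y is X2, w is X3)  # True True True
```

Under this convention, `X1` alone is treated as the model input, `X2` as the targets and `X3` as the sample weights. That would match "got the following list of 1 arrays" even though `len(next(train_pair_gen))` is 3.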
What am I doing wrong?
Network definition (summary of the base network):
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_131 (InputLayer)       (None, 64, 64, 3)         0
_________________________________________________________________
conv2d_100 (Conv2D)          (None, 29, 29, 32)        4736
_________________________________________________________________
conv2d_101 (Conv2D)          (None, 8, 8, 64)          73792
_________________________________________________________________
conv2d_102 (Conv2D)          (None, 1, 1, 128)         204928
_________________________________________________________________
flatten_34 (Flatten)         (None, 128)               0
=================================================================
Total params: 283,456
Trainable params: 283,456
Non-trainable params: 0
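The shapes and parameter counts in the summary are consistent with 'valid' convolutions using kernel sizes 7, 6 and 5 and strides 2, 3 and 4. These values are inferred from the numbers and are not given in the post; for the last layer, any stride of 4 or more yields the same 1×1 output. A small Python check of that arithmetic, using output = floor((n − k) / s) + 1 and params = (k·k·c_in + 1)·c_out:

```python
def conv_out(n, k, s):
    """Spatial size after a 'valid' 2-D convolution with kernel k and stride s."""
    return (n - k) // s + 1

def conv_params(k, c_in, c_out):
    """k*k*c_in weights per filter plus one bias per filter."""
    return (k * k * c_in + 1) * c_out

# (kernel, stride, filters): inferred from the summary, not taken from create_base_network
layers = [(7, 2, 32), (6, 3, 64), (5, 4, 128)]

size, channels, total = 64, 3, 0
for k, s, filters in layers:
    size = conv_out(size, k, s)
    params = conv_params(k, channels, filters)
    total += params
    channels = filters
    print((size, size, channels), params)
print("flatten:", size * size * channels, "total:", total)
# (29, 29, 32) 4736
# (8, 8, 64) 73792
# (1, 1, 128) 204928
# flatten: 128 total: 283456
```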