我为连体网络创建了一个自定义生成器,该生成器为网络三元组提供图像:
def __getitem__(self, index):
...preprocessing of anchors positive and negative matches
return anchors, positives, negatives
从生成器创建 tf.data 数据集的正确方法是什么?如果我将三元组包装在元组/列表中,我在下面的尝试会导致生成器出错。
ds = tf.data.Dataset.from_generator(
lambda: train_gen,
output_types=(tf.float32, tf.float32, tf.float32),
output_shapes=([None,224,224,3], [None,224,224,3], [None,224,224,3])
但是如果我这样做了,我会得到同样的错误,但是当我喂它时来自模型ds
。
Layer model expects 3 input(s), but it received 1 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(None, None, None, None) dtype=float32>]
该错误是不言自明的。我只是不明白如何在创建数据集时[a, p, n]
从函数中获取它,但不是在我调用时而是在使用时tf.data.dataset.from_generator()
.fit(ds)
.predict(ds)
感谢您提供的任何积分,谢谢!