0

我为连体网络创建了一个自定义生成器,该生成器为网络三元组提供图像:

def __getitem__(self, index):

...preprocessing of anchors positive and negative matches
    
return anchors, positives, negatives

从生成器创建 tf.data 数据集的正确方法是什么?如果我将三元组包装在元组/列表中,我在下面的尝试会导致生成器出错。

ds = tf.data.Dataset.from_generator(
lambda: train_gen, 
output_types=(tf.float32, tf.float32, tf.float32), 
output_shapes=([None,224,224,3], [None,224,224,3], [None,224,224,3])

但是如果我这样做了,我会得到同样的错误,但是当我喂它时来自模型ds

Layer model expects 3 input(s), but it received 1 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(None, None, None, None) dtype=float32>]

该错误是不言自明的。我只是不明白如何在创建数据集时[a, p, n]从函数中获取它,但不是在我调用时而是在使用时tf.data.dataset.from_generator().fit(ds).predict(ds)

感谢您提供的任何积分,谢谢!

4

0 回答 0