0

我目前正在使用 StyleGAN2 从我自己的数据集中生成图像。我在TF2.X 中使用 StyleGAN2 的这种实现,但我不知道如何实际添加具有多个标签的自定义数据集。生成没有标签的图像按预期工作。

现在我只是在使用 python 生成器和tf.data.Dataset.from_generator. 我的生成器产生一个元组(image, label1, label2, label3)并将它们传递给我的模型。这不起作用。

def create_dataset(image_size, batch_size):
dataset = tf.data.Dataset.from_generator(
    _data_generator,
    output_signature=(
        # Image must be in NCHW format
        tf.TensorSpec(shape=(3, image_size, image_size), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.float32)
    )
)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
return dataset

模型期望传递的对象具有相同的等级,我的图像具有形状: [4, 3, 64, 64]并且我的标签只是具有形状的简单浮点数[4]。(格式为 NCHW)

最初的StyleGAN2也没有具体描述类条件,而是使用dataset_tool.py文件来生成TFRecord文件。但是有人知道将标记数据传递给模型的另一种方法吗?

4

0 回答 0