我目前正在使用 StyleGAN2 从我自己的数据集中生成图像。我在TF2.X 中使用 StyleGAN2 的这种实现,但我不知道如何实际添加具有多个标签的自定义数据集。生成没有标签的图像按预期工作。
现在我只是在使用 python 生成器和tf.data.Dataset.from_generator
. 我的生成器产生一个元组(image, label1, label2, label3)
并将它们传递给我的模型。这不起作用。
def create_dataset(image_size, batch_size):
dataset = tf.data.Dataset.from_generator(
_data_generator,
output_signature=(
# Image must be in NCHW format
tf.TensorSpec(shape=(3, image_size, image_size), dtype=tf.float32),
tf.TensorSpec(shape=(), dtype=tf.float32),
tf.TensorSpec(shape=(), dtype=tf.float32),
tf.TensorSpec(shape=(), dtype=tf.float32)
)
)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
return dataset
模型期望传递的对象具有相同的等级,我的图像具有形状:
[4, 3, 64, 64]
并且我的标签只是具有形状的简单浮点数[4]
。(格式为 NCHW)
最初的StyleGAN2也没有具体描述类条件,而是使用dataset_tool.py
文件来生成TFRecord
文件。但是有人知道将标记数据传递给模型的另一种方法吗?