
I'm working on a 2D semantic segmentation task.

The Keras API documentation only has examples of how to arrange a dataset for image classification, not for semantic segmentation.

So I set up my images and labels like this:

```python
import numpy as np
from keras.preprocessing.image import ImageDataGenerator

SEED = 111
batch_size = 2

image_datagen = ImageDataGenerator(
    horizontal_flip=True,
    zca_epsilon=9,  # only used when zca_whitening=True
    # fill_mode='nearest',
)
image_generator = image_datagen.flow_from_directory(
    directory="/xxx/images",
    class_mode=None,  # yield images only, no class labels
    batch_size=batch_size,
    seed=SEED,  # same seed as the labels, so shuffling and flips match
)


def preprocessing_function(image):
    # Intended to cast each label image to uint8
    return image.astype(np.uint8)


label_datagen = ImageDataGenerator(
    horizontal_flip=True,
    zca_epsilon=9,
    rescale=1,
    preprocessing_function=preprocessing_function,
    # fill_mode='nearest',
)
# Build the labels from label_datagen so its settings are actually applied
label_generator = label_datagen.flow_from_directory(
    directory="/xxx/labels",
    class_mode=None,
    batch_size=batch_size,
    seed=SEED,
)

train_generator = zip(image_generator, label_generator)
print(len(image_generator))  # number of batches per epoch
for i, (image_batch, label_batch) in enumerate(train_generator):
    print(image_batch.shape, label_batch.shape)  # (2, 256, 256, 3) (2, 256, 256, 3)
    print(image_batch.dtype, label_batch.dtype)  # float32 float32
    if i == 4:  # inspect the first 5 batches only
        break
```
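Pairing the two iterators with `zip` only works if both visit the files in the same order and make the same random flip for each image/label pair, which is what passing the same `SEED` to both `flow_from_directory` calls is for. It also assumes the image and label folders hold the same number of files with names that sort into matching order. Below is a simplified NumPy-only model of that idea; `shuffle_and_flip` is a hypothetical helper, not a Keras function, and Keras draws its random numbers differently internally:

```python
import numpy as np

SEED = 111  # same seed as in the question


def shuffle_and_flip(batch, seed):
    """Hypothetical helper: shuffle samples, then flip each left-right with p=0.5."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(batch))   # same seed -> same shuffle order
    out = batch[order]                    # fancy indexing returns a copy
    for i in range(len(out)):
        if rng.random() < 0.5:            # same seed -> same flip decision
            out[i] = out[i, :, ::-1].copy()  # flip the width axis of (H, W, C)
    return out


# Fake data: 4 "images" of shape (H=2, W=3, C=1) and matching uint8 "masks"
images = np.arange(4 * 2 * 3, dtype=np.float32).reshape(4, 2, 3, 1)
masks = images.astype(np.uint8)

aug_images = shuffle_and_flip(images, SEED)
aug_masks = shuffle_and_flip(masks, SEED)

# Each mask still lines up with its image after shuffling and flipping
print(np.array_equal(aug_images.astype(np.uint8), aug_masks))  # True
```

If the two calls used different seeds, the shuffle order and flips would generally diverge and the masks would no longer match their images.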

But I found that the generated label images have dtype float32, so I added a `preprocessing_function` to `label_datagen` whose only job is to cast them to uint8. However, the generated labels are still float32; it seems that `preprocessing_function` does nothing.
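A likely reason the cast seems to vanish: in `keras_preprocessing`-based versions of Keras, the directory iterator preallocates each batch array with the generator's `dtype` (float32 by default), runs `preprocessing_function` on every sample, and then copies the result into that buffer, so the uint8 values are converted straight back to float32. A simplified NumPy model of that behaviour; the loop is an assumption standing in for Keras's internal batch loop, not its actual code:

```python
import numpy as np


def preprocessing_function(image):
    # Same cast as in the question
    return image.astype(np.uint8)


# Simplified model (assumption): preallocate the batch with the generator's
# dtype, preprocess each sample, then write it into the buffer.
batch_size, shape = 2, (4, 4, 3)
batch = np.zeros((batch_size,) + shape, dtype="float32")
for i in range(batch_size):
    sample = np.full(shape, 7.0, dtype="float32")  # stand-in for a loaded label
    processed = preprocessing_function(sample)
    print(processed.dtype)  # uint8
    batch[i] = processed    # copied into the float32 buffer
print(batch.dtype)          # float32: the cast is undone at batch level
```

If your installed `ImageDataGenerator` accepts a `dtype` argument (newer `keras_preprocessing` releases do), passing `dtype='uint8'` to `label_datagen` should change the buffer type directly; check your version's signature before relying on it. Also make sure `label_generator` is created from `label_datagen.flow_from_directory`, otherwise none of `label_datagen`'s settings are used.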

How can I fix this?

How do I get my label data as uint8?

Is it common practice to add a preprocessing function that casts the dtype of label images?

Thanks for any advice!


1 Answer


I ran into the same problem and wrapped the generator in another generator. It works, but it's a bit hacky:

```python
# Wrap label_generator so every label batch is cast to uint8
label_generator = (x.astype(np.uint8) for x in label_generator)
train_generator = zip(image_generator, label_generator)
```
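The same wrapping idea can also be applied after `zip`, so that only the labels are cast while the images stay float32 for the network input. A minimal sketch; `with_uint8_labels` is a hypothetical helper name, and the fake arrays stand in for the real Keras iterators:

```python
import numpy as np


def with_uint8_labels(pairs):
    """Yield (images, labels) with labels cast to uint8; images are left unchanged."""
    for images, labels in pairs:
        yield images, labels.astype(np.uint8)


# Stand-ins for image_generator / label_generator (hypothetical fake data)
fake_images = (np.random.rand(2, 8, 8, 3).astype(np.float32) for _ in range(3))
fake_labels = (np.full((2, 8, 8, 3), 1.0, dtype=np.float32) for _ in range(3))

train_generator = with_uint8_labels(zip(fake_images, fake_labels))
for image_batch, label_batch in train_generator:
    print(image_batch.dtype, label_batch.dtype)  # float32 uint8
```

The cast is lossless only if the label pixel values are already whole numbers in 0 to 255, which holds as long as no rescaling or interpolating transform is applied to the labels. For single-channel masks, passing `color_mode='grayscale'` to `flow_from_directory` gives label batches of shape `(batch, H, W, 1)` instead of `(batch, H, W, 3)`.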
answered 2019-06-17T20:50:11.207