
I am using federated learning to classify handwritten math symbols. I have preprocessed the images with keras.preprocessing.image.ImageDataGenerator and obtained the label for each image.

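from keras.preprocessing.image import ImageDataGenerator

# Rescale pixels to [0, 1] and apply light random augmentation.
# Note: horizontal flips can turn asymmetric symbols into other classes (e.g. '(' into ')').
train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
# One subdirectory per class: labels are inferred from the folder names
# and returned one-hot because class_mode='categorical'.
train_dataset = train_datagen.flow_from_directory(
        'train_test_data/train/',
        target_size=(45,45),
        batch_size=32,
        class_mode='categorical')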

Getting the labels:

import os

# Build the label list from paths such as '!/exp87530.jpg':
# the class label is the name of the folder the image sits in.
def make_labels(train_dataset):
  labels = []
  for filename in train_dataset.filenames:
    labels.append(filename.split(os.path.sep)[0])
  return labels
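Note that the iterator returned by flow_from_directory already exposes this information: `train_dataset.class_indices` maps folder names to integer ids and `train_dataset.classes` holds the integer id of every file in `train_dataset.filenames`. If you still want to work from the folder-name list, here is a minimal sketch (the helper `make_label_ids` is hypothetical) that turns it into integer ids, assuming the same alphanumeric class ordering Keras uses when it infers classes from subdirectory names. Unlike Keras, it only sees classes that actually have files.

```python
import os

def make_label_ids(filenames, sep=os.path.sep):
    """Map 'class_dir/file.jpg' paths to (sorted class names, integer ids)."""
    # Folder name = class label, as in make_labels() above.
    names = [f.split(sep)[0] for f in filenames]
    # Keras orders inferred classes alphanumerically; mirror that with sorted().
    class_names = sorted(set(names))
    index = {name: i for i, name in enumerate(class_names)}
    return class_names, [index[n] for n in names]

# Hypothetical file list in the same layout as train_dataset.filenames.
files = [os.path.join("!", "exp87530.jpg"),
         os.path.join("+", "exp1.jpg"),
         os.path.join("!", "exp2.jpg")]
class_names, ids = make_label_ids(files)
print(class_names, ids)  # ['!', '+'] [0, 1, 0]
```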

How do I build the flattened (image, label) tuples that need to be sent to the clients, as in the TensorFlow tutorial Building Your Own Federated Learning Algorithm?

From the tutorial:

import tensorflow as tf
import tensorflow_federated as tff
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]), 
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)
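The tutorial's batch_format_fn flattens EMNIST `pixels` to 784 = 28*28 features and reshapes the integer labels into a `[-1, 1]` column. A batch from `train_dataset` above instead has images of shape (batch, 45, 45, channels) and, with class_mode='categorical', one-hot labels. The following NumPy sketch (the function name `batch_format_np` and the fake batch are assumptions for illustration) shows the same transformation for that layout: 45*45*channels features per image, and argmax to recover integer labels. Alternatively, class_mode='sparse' in flow_from_directory returns integer labels directly, and color_mode='grayscale' gives 1 channel instead of 3.

```python
import numpy as np

def batch_format_np(x, y):
    """Flatten an ImageDataGenerator batch into a (features, label) tuple.

    x: (batch, height, width, channels) float images
    y: (batch, num_classes) one-hot labels (class_mode='categorical')
    """
    # Same idea as tf.reshape(x, [-1, height * width * channels]).
    features = x.reshape(x.shape[0], -1)
    # One-hot -> integer id column, matching the tutorial's [-1, 1] label shape.
    labels = np.argmax(y, axis=1).reshape(-1, 1)
    return features, labels

# Fake batch standing in for one batch of train_dataset
# (assumed grayscale, i.e. 1 channel; num_classes = 10 is hypothetical).
num_classes = 10
rng = np.random.default_rng(0)
x = rng.random((32, 45, 45, 1), dtype=np.float32)
y = np.eye(num_classes, dtype=np.float32)[rng.integers(0, num_classes, size=32)]

features, labels = batch_format_np(x, y)
print(features.shape, labels.shape)  # (32, 2025) (32, 1)
```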

1 Answer


You could try something like this:

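import tensorflow as tf

# Download the TensorFlow flowers example (5 classes, one subfolder per class).
flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

# Wrap the Keras iterator in a tf.data.Dataset. Each element is one batch:
# RGB images of shape (batch, 45, 45, 3) and one-hot labels of shape (batch, 5).
# The Keras iterator loops forever, so this dataset is infinite; use .take(n).
ds = tf.data.Dataset.from_generator(
    lambda: img_gen.flow_from_directory(flowers, target_size=(45,45), batch_size=10, shuffle=True),
    output_signature=(
        tf.TensorSpec(shape=(None, 45, 45, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 5), dtype=tf.int32)))

def preprocess(dataset):

  def batch_format_fn(x, y):
    """Flatten a batch of images and return a (features, label) tuple."""
    return (tf.reshape(x, [-1, 45*45*3]),  # 45*45*3 = 6075 values per image
            tf.reshape(y, [-1, 5]))        # labels stay one-hot, 5 classes

  return dataset.map(batch_format_fn)

ds = preprocess(ds)
for x, y in ds.take(1):
  print(x.shape, y.shape)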

This gives flattened batches: each image becomes 45*45*3 = 6075 values, and 5 is the number of classes (distinct labels):

(10, 6075) (10, 5)
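The answer stops at flattened batches; it does not show how the data is divided among clients. In the TFF tutorial, EMNIST is already partitioned by writer and one preprocessed dataset is built per client id. A plain image directory has no natural client id, so the sketch below assumes a simple round-robin (IID) split into `num_clients` shards of (features, label) batches. The function `split_into_clients` and the zero-filled data are hypothetical. Each shard's arrays could then be wrapped with `tf.data.Dataset.from_tensor_slices((fx, fy)).batch(batch_size)` to obtain the per-client list of datasets that the tutorial's training loop expects.

```python
import numpy as np

def split_into_clients(features, labels, num_clients, batch_size):
    """Return, for each simulated client, a list of (features, label) batches."""
    clients = []
    for c in range(num_clients):
        # Round-robin (IID) split: example i goes to client i % num_clients.
        fx, fy = features[c::num_clients], labels[c::num_clients]
        batches = [(fx[i:i + batch_size], fy[i:i + batch_size])
                   for i in range(0, len(fx), batch_size)]
        clients.append(batches)
    return clients

# Hypothetical data shaped like the answer's output: 6075 features, 5 one-hot classes.
features = np.zeros((100, 45 * 45 * 3), dtype=np.float32)
labels = np.eye(5, dtype=np.float32)[np.arange(100) % 5]

clients = split_into_clients(features, labels, num_clients=10, batch_size=4)
print(len(clients), len(clients[0]), clients[0][0][0].shape)  # 10 3 (4, 6075)
```

A round-robin split keeps every client's label distribution roughly equal; real federated data is usually non-IID, so a split by writer or by class would be a more realistic simulation.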
answered 2022-02-04T18:20:40.340