I am working on a project with TensorFlow Federated. I have managed to load, train, and test on some datasets using the simulation datasets provided by TensorFlow Federated.
For example, I loaded the EMNIST dataset:
import tensorflow_federated as tff
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
The datasets returned by load_data() are instances of tff.simulation.ClientData, an interface that lets me iterate over the client IDs and select subsets of the data for simulations:
len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None)), ('label', TensorSpec(shape=(), dtype=tf.int32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])
Now I am trying to load the fashion_mnist dataset with Keras to perform some federated operations:
fashion_train, fashion_test = tf.keras.datasets.fashion_mnist.load_data()
But I get this error:
AttributeError: 'tuple' object has no attribute 'element_spec'
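because Keras returns a tuple of NumPy arrays instead of a tff.simulation.ClientData as before, so the following code fails at input_spec=fashion_test.element_spec: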
def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=factory.retrieve_model(True),
        input_spec=fashion_test.element_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

# Pass the optimizers by keyword: the second positional parameter of
# build_federated_averaging_process is client_optimizer_fn, not the server one.
iterative_process = tff.learning.build_federated_averaging_process(
    tff_model_fn,
    client_optimizer_fn=Parameters.client_adam_optimizer_fn,
    server_optimizer_fn=Parameters.server_adam_optimizer_fn)
server_state = iterative_process.initialize()
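As a minimal sketch, assuming Fashion-MNIST's usual Keras layout (uint8 images of shape (N, 28, 28) and integer labels), the NumPy arrays can be converted into an OrderedDict with the same keys, per-example shapes, and dtypes that emnist_train.element_type_structure reports above. The function name `to_emnist_like` and the division by 255 are illustrative choices, not part of TFF.

```python
from collections import OrderedDict

import numpy as np


def to_emnist_like(images, labels):
    """Convert Keras-style (images, labels) NumPy arrays into an OrderedDict
    mirroring emnist_train.element_type_structure:
    'pixels' -> float32 array of shape (N, 28, 28), 'label' -> int32 array (N,).
    (Hypothetical helper; the [0, 1] scaling is an assumption.)
    """
    images = np.asarray(images)
    labels = np.asarray(labels)
    if images.shape[0] != labels.shape[0]:
        raise ValueError("images and labels must have the same length")
    return OrderedDict([
        # uint8 values in [0, 255] -> float32 values in [0, 1]
        ("pixels", images.astype(np.float32) / 255.0),
        ("label", labels.astype(np.int32)),
    ])
```

Feeding the result to `tf.data.Dataset.from_tensor_slices` should give a dataset whose `element_spec` is an `OrderedDict` of `TensorSpec`s like the EMNIST one. Any preprocessing and batching used for training would still have to be applied to that dataset before using its `element_spec` as `input_spec`.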
To sum up, is there any way to create a tff.simulation.ClientData from the tuple of NumPy arrays returned by Keras?

Another solution that comes to my mind is to use tff.simulation.HDF5ClientData and manually load the appropriate files in HDF5 format (train.h5, test.h5) in order to get a tff.simulation.ClientData. My problem is that I can't find the URL for the fashion_mnist HDF5 files, i.e. something like this for both train and test:
# Excerpt from the body of a loader function; cache_dir is defined by the enclosing function.
fileprefix = 'fed_emnist_digitsonly'
sha256 = '55333deb8546765427c385710ca5e7301e16f4ed8b60c1dc5ae224b42bd5b14b'
filename = fileprefix + '.tar.bz2'
path = tf.keras.utils.get_file(
    filename,
    origin='https://storage.googleapis.com/tff-datasets-public/' + filename,
    file_hash=sha256,
    hash_algorithm='sha256',
    extract=True,
    archive_format='tar',
    cache_dir=cache_dir)
dir_path = os.path.dirname(path)
train_client_data = hdf5_client_data.HDF5ClientData(
    os.path.join(dir_path, fileprefix + '_train.h5'))
test_client_data = hdf5_client_data.HDF5ClientData(
    os.path.join(dir_path, fileprefix + '_test.h5'))
return train_client_data, test_client_data
My final goal is to make the fashion_mnist dataset work with TensorFlow Federated.
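Unlike TFF's EMNIST, which is grouped by writer, Fashion-MNIST has no natural per-client grouping, so a ClientData-like structure needs an artificial split. A minimal sketch, assuming a uniform random (IID) split into a chosen number of shards, builds a dict from client id to per-client arrays. The helper name `partition_into_clients` and the `client_{i}` ids are hypothetical.

```python
from collections import OrderedDict

import numpy as np


def partition_into_clients(images, labels, num_clients, seed=0):
    """Split centralized (images, labels) arrays into `num_clients` IID shards.

    Returns an OrderedDict mapping a client id string to an OrderedDict with
    'pixels' (float32, scaled to [0, 1]) and 'label' (int32) arrays.
    (Hypothetical helper; the IID split and scaling are assumptions.)
    """
    images = np.asarray(images)
    labels = np.asarray(labels)
    if images.shape[0] != labels.shape[0]:
        raise ValueError("images and labels must have the same length")
    rng = np.random.default_rng(seed)
    # Shuffle example indices once, then cut them into near-equal shards.
    order = rng.permutation(images.shape[0])
    clients = OrderedDict()
    for i, idx in enumerate(np.array_split(order, num_clients)):
        clients[f"client_{i}"] = OrderedDict([
            ("pixels", images[idx].astype(np.float32) / 255.0),
            ("label", labels[idx].astype(np.int32)),
        ])
    return clients
```

Depending on the installed TFF version, a dict of this shape may be accepted by `tff.simulation.FromTensorSlicesClientData` (older releases) or `tff.simulation.datasets.TestClientData` (newer releases); check which one your version exposes. A non-IID split (for example, grouping by label) would only change how `order` is built.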