2

我想用每轮用户的随机样本来模拟这个联邦学习的代码进行图像分类,本教程使用所有客户端进行训练,实际上,我想以这样的方式修改这段代码,在每一轮随机样本客户被选中。所以我们可以在这段代码中改变什么来强制它随机选择客户端

import collections
import time

import tensorflow as tf
tf.compat.v1.enable_v2_behavior()

import tensorflow_federated as tff

source, _ = tff.simulation.datasets.emnist.load_data()


def map_fn(example):
  return collections.OrderedDict(
      x=tf.reshape(example['pixels'], [-1, 784]), y=example['label'])


def client_data(n):
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds.repeat(10).shuffle(500).batch(20).map(map_fn)


train_data = [client_data(n) for n in range(10)]
element_spec = train_data[0].element_spec

def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(units=10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
  return tff.learning.from_keras_model(
      model,
      input_spec=element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])


trainer = tff.learning.build_federated_averaging_process(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02))

....
NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = trainer.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
4

2 回答 2

1

tff.simulation.ClientData对象公开一个client_ids属性,该属性表示标识此数据集中用户的字符串列表。

因此,您可以直接从该列表中采样,并create_tf_dataset_for_client在同一对象上使用该方法创建该用户数据的数据集。假设一个tff.simulation.ClientDataobject client_data,伪代码看起来像:

import random
...

for round_num in range(2, NUM_ROUNDS):
  selected_clients = random.sample(client_data.client_ids, USERS_PER_ROUND)
  federated_data = [
      client_data.create_tf_dataset_for_client(n) for n in selected_clients]
  state, metrics = iterative_process.next(state, federated_data)

TFF 中包含的许多研究代码在某种程度上将选择客户的关注与运行训练循环分开,所以我不能真正指出这种模式的一个很好的例子——但我认为 TFF 很乐意接受贡献更新教程使用这样的模式,以帮助ClientData更好地展示 API 的灵活性。

于 2020-04-04T20:00:08.490 回答
0

这将做到(遵循基思的伪代码):

selected_clients = np.random.choice(emnist_train.client_ids, size=USERS_PER_ROUND)
round_federated_train_data = make_federated_data(emnist_train, selected_clients)
于 2021-06-11T11:14:41.810 回答