这可能在模拟期间写在外部 Python 循环中。当前的 API 没有在一轮内同时进行评估和训练的概念。
如果使用 TFF 中包含的模拟数据集(例如 下的那些tff.simulation.datasets
),它们包括一个训练/测试拆分,这使得这很容易。每个返回一个tff.simulation.ClientData对象的 2 元组,一个 test 和一个 train ClientData
。test 和 train 都具有相同的ClientData.client_id
列表,但tf.data.Dataset
返回的create_tf_dataset_for_client(client_id)
将有一组不相交的示例。
换句话说,训练和测试的划分是针对用户示例,而不是针对用户。
联合训练和联合评估循环可能如下所示:
train_data, test_data = tff.simulation.datasets.shakespeare.load_data()
federated_average = tff.learning.build_federated_averaging_process(model_fn, ...)
federated_eval = tff.learning.build_federated_evaluation(model_fn)
state = federated_average.initialize()
for _ in range(NUM_ROUNDS):
participating_clients = numpy.random.choice(train_data.client_ids, size=5)
# Run a training pass
clients_train_datasets = [
train_data.create_tf_dataset_for_client(c) for c in participating_clients
]
state, train_metrics = federated_average.next(state, client_train_datasets)
# Run an evaluation pass
client_eval_datasets = [
test_data.create_tf_dataset_for_client(c) for c in participating_clients
]
eval_metrics = federated_eval(state.model, client_eval_datasets)