我从我自己的数据集中用 TFF 编写了一个代码,除了这一行之外,所有代码都可以正常运行
在 train_data 中,我制作了 4 个数据集,加载了 tf.data.Dataset,它们的类型为“DatasetV1Adapter”
def client_data(n):
ds = source.create_tf_dataset_for_client(source.client_ids[n])
return ds.repeat(10).map(map_fn).shuffle(500).batch(20)
federated_train_data = [client_data(n) for n in range(4)]
batch = tf.nest.map_structure(lambda x: x.numpy(), iter(train_data[0]).next())
def model_fn():
model = tf.keras.models.Sequential([
.........
return tff.learning.from_compiled_keras_model(model, batch)
所有这些都运行正确,我得到了教练和状态:
trainer = tff.learning.build_federated_averaging_process(model_fn)
除了,当我要开始训练和使用这段代码时:
state, metrics = iterative_process.next(state, federated_train_data)
print('round 1, metrics={}'.format(metrics))
我不能。错误来了!那么,错误可能来自哪里?从数据集的类型?或者我让我的数据联合的方式?