0

我想将使用 FedAvg 算法训练的 TensorFlow 联合模型保存为 Keras/.h5 模型。我找不到这方面的文件,想知道如何做。此外,如果可能的话,我想同时访问聚合服务器模型和客户端模型。

我用来训练联合模型的代码如下:

def model_fn():
    model = tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(segment_size,num_input_channels)),
      tf.keras.layers.Flatten(), 
      tf.keras.layers.Dense(units=400, activation='relu'),
      tf.keras.layers.Dropout(dropout_rate),
      tf.keras.layers.Dense(units=100, activation='relu'),
      tf.keras.layers.Dropout(dropout_rate),
      tf.keras.layers.Dense(activityCount, activation='softmax'),
    ])
    return tff.learning.from_keras_model(
      model,
      dummy_batch=batch,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
trainer = tff.learning.build_federated_averaging_process(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learningRate))

def evaluate(num_rounds=communicationRound):
  state = trainer.initialize()
  roundMetrics = []
  evaluation = tff.learning.build_federated_evaluation(model_fn)

  for round_num in range(num_rounds):
    t1 = time.time()
    state, metrics = trainer.next(state, train_data)
    t2 = time.time()
    test_metrics = evaluation(state.model, train_data)

    roundMetrics.append('round {:2d}, metrics={}, loss={}'.format(round_num, metrics.sparse_categorical_accuracy , metrics.loss))
    roundMetrics.append("The test accuracy is " + str(test_metrics.sparse_categorical_accuracy))
    roundMetrics.append('round time={}'.format(t2 - t1))
    print('round {:2d}, accuracy={}, loss={}'.format(round_num, metrics.sparse_categorical_accuracy , metrics.loss))
    print("The test accuracy is " + str(test_metrics.sparse_categorical_accuracy))
    print('round time={}'.format(t2 - t1))
  outF = open(filepath+'stats'+architectureType+'.txt', "w")
  for line in roundMetrics:
    outF.write(line)
    outF.write("\n")
  outF.close()
4

1 回答 1

3

粗略地说,我们将使用 save_checkpoint/load_checkpoint 方法。特别是,您可以实例化 FileCheckpointManager,并要求它(几乎)直接保存状态。

您的示例中的 state 是 tff.python.common_libs.anonymous_tuple.AnonymousTuple (IIRC) 的一个实例,它与 tf.convert_to_tensor 不兼容,正如 save_checkpoint 所需要并在其文档字符串中声明的那样。TFF 研究代码中经常使用的通用解决方案是引入一个 Python attrs 类,以便在返回状态后立即从匿名元组转换——

假设上述情况,以下草图应该有效:

# state assumed an anonymous tuple, previously created
# N some integer 

ckpt_manager = FileCheckpointManager(...)
ckpt_manager.save_checkpoint(ServerState.from_anon_tuple(state), round_num=N)

要从此检查点恢复,您可以随时调用:

state = iterative_process.initialize()
ckpt_manager = FileCheckpointManager(...)
restored_state = ckpt_manager.load_latest_checkpoint(
    ServerState.from_anon_tuple(state))

需要注意的一点:上面链接的代码指针一般在tff.python.research...中,pip包中不包含;因此,获得它们的首选方法是将代码分叉到您自己的项目中,或者拉下存储库并从源代码构建它。

于 2020-03-28T09:37:47.590 回答