8

我想将这篇文中的循环自动编码器改编为在联合环境中工作。

我稍微修改了模型以符合TFF 图像分类教程中显示的示例。

def create_compiled_keras_model():
  model = tf.keras.models.Sequential([
      tf.keras.layers.LSTM(2, input_shape=(10, 2), name='Encoder'),
      tf.keras.layers.RepeatVector(10, name='Latent'),
      tf.keras.layers.LSTM(2, return_sequences=True, name='Decoder')]
  )

  model.compile(loss='mse', optimizer='adam')
  return model

model = create_compiled_keras_model()

sample_batch = gen(1)
timesteps, input_dim = 10, 2

def model_fn():
  keras_model = create_compiled_keras_model()
  return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

gen 函数定义如下:

import random

def gen(batch_size):
    seq_length = 10

    batch_x = []
    batch_y = []

    for _ in range(batch_size):
        rand = random.random() * 2 * np.pi

        sig1 = np.sin(np.linspace(0.0 * np.pi + rand, 3.0 * np.pi + rand, seq_length * 2))
        sig2 = np.cos(np.linspace(0.0 * np.pi + rand, 3.0 * np.pi + rand, seq_length * 2))

        x1 = sig1[:seq_length]
        y1 = sig1[seq_length:]
        x2 = sig2[:seq_length]
        y2 = sig2[seq_length:]

        x_ = np.array([x1, x2])
        y_ = np.array([y1, y2])
        x_, y_ = x_.T, y_.T

        batch_x.append(x_)
        batch_y.append(y_)

    batch_x = np.array(batch_x)
    batch_y = np.array(batch_y)

    return batch_x, batch_x #batch_y

到目前为止,我一直无法找到任何不使用 TFF 存储库中的示例数据的文档。

如何修改它以创建联合数据集并开始训练?

4

2 回答 2

5

在非常高的层次上,要使用带有 TFF 的任意数据集,需要以下步骤:

  1. 将数据集划分为每个客户端子集(如何做到这一点是一个更大的问题)
  2. 为每个客户端子集创建一个tf.data.Dataset
  3. 将所有(或子集)数据集对象的列表传递给联合优化。

教程中发生了什么

图像分类的联邦学习教程使用tff.learning.build_federated_averaging_process来构建使用 FedAvg 算法的联邦优化。

在该笔记本中,以下代码正在执行一轮联合优化,其中客户端数据集被传递给 process'.next方法:

   state, metrics = iterative_process.next(state, federated_train_data)

federated_train_data是一个 Python listtf.data.Dataset每个参与这一轮的客户一个。

ClientData 对象

TFF 提供的预制数据集(在tff.simulation.datasets下)使用tff.simulation.ClientData接口实现,该接口管理客户端 → 数据集映射和tff.data.Dataset创建。

如果您计划重复使用数据集,则将其实现为 atff.simulation.ClientData可能会使将来的使用更容易。

于 2019-04-01T20:15:23.590 回答
1

接受的答案得到了很好的解释。如果你们需要将张量转换为 clientdata 对象的代码实现,可以在这个github 存储库中找到。

我曾经tff.simulation.FromTensorSlicesClientData将 mnist 数据集转换为多个 tff 客户端数据。

于 2020-05-08T09:28:22.113 回答