我想让我的 TensorFlow Federated 代码得到可重现的结果。为此，我设置了若干随机种子（random、numpy 和 tensorflow），但它们对 TensorFlow Federated 不起作用。数据处理步骤都是可重现的，所以问题一定出在下面的代码片段中。
我读到过 TensorFlow Federated 不提供全局种子功能，唯一的办法是保存状态，但我不理解这一说法。有没有人知道能解决这个问题的方法或函数，或者能向我解释为什么在 TensorFlow Federated 中无法使用种子？
欢迎任何评论 :) 感谢您的帮助。
# Patch asyncio through nest_asyncio before any federated computation runs
# (presumably needed because the script runs inside an already-running
# event loop such as a notebook -- TODO confirm).
nest_asyncio.apply()

# One shared seed for the Python, NumPy and TensorFlow global generators.
# The seeders are called in the order: random, numpy, tensorflow.
seed_value = 0
for seed_rng in (random.seed, np.random.seed, tf.random.set_seed):
    seed_rng(seed_value)
# Design the clients: split the training arrays into `num_clients`
# contiguous shards of equal length and register each shard under the key
# "Client_<n>" (n counted from 1). Samples beyond
# `num_clients * (len(X_train) // num_clients)` are not assigned to anyone.
client_train_data = collections.OrderedDict()
for client_number in range(1, num_clients + 1):
    shard_len = len(X_train) // num_clients
    shard = slice(shard_len * (client_number - 1), shard_len * client_number)
    client_train_data["Client_{}".format(client_number)] = collections.OrderedDict(
        label=y_train[shard],
        features=X_train[shard],
    )

# Wrap the per-client dictionaries in a TFF simulation ClientData object.
train_dataset = tff.simulation.FromTensorSlicesClientData(client_train_data)
def preprocess(dataset, seed=None):
    """Turn one client's raw dataset into shuffled, batched training input.

    The pipeline repeats the data `num_epochs` times, shuffles it with a
    buffer of `shuffle_buffer` elements, groups it into batches of
    `batch_size`, reshapes every batch into the (x, y) structure built by
    `batch_format`, and prefetches `prefetch_buffer` batches.

    Args:
        dataset: Dataset whose elements expose the keys "features" and
            "label".
        seed: Op-level seed for the shuffle. When omitted, the module-level
            `seed_value` is used so that repeated runs shuffle identically.

    Returns:
        The transformed dataset yielding OrderedDicts with keys "x" and "y".
    """
    if seed is None:
        # Looked up at call time so the function can be defined before the
        # module-level seed is assigned.
        seed = seed_value

    def batch_format(element):
        # x becomes [-1, 11] (11 feature columns), y becomes [-1, 1].
        return collections.OrderedDict(
            x=reshape(element["features"], [-1, 11]),
            y=reshape(element["label"], [-1, 1]))

    # An explicit op-level seed pins the shuffle order instead of relying on
    # the global TensorFlow seed reaching the context that iterates the data.
    return (
        dataset.repeat(num_epochs)
        .shuffle(shuffle_buffer, seed=seed)
        .batch(batch_size)
        .map(batch_format)
        .prefetch(prefetch_buffer)
    )
def make_federated_data(client_data, client_ids):
    """Build the list of preprocessed datasets for the given clients.

    Args:
        client_data: Object providing `create_tf_dataset_for_client`.
        client_ids: Iterable of client identifiers to include.

    Returns:
        A list with one `preprocess`-ed dataset per entry of `client_ids`,
        in the same order.
    """
    # `map` is lazy, so each client's dataset is created right before it is
    # preprocessed, one client at a time.
    raw_datasets = map(client_data.create_tf_dataset_for_client, client_ids)
    return [preprocess(raw) for raw in raw_datasets]
# Federated training input: one preprocessed dataset for every client id
# known to `train_dataset`.
fl_train_data = make_federated_data(
    client_data=train_dataset,
    client_ids=train_dataset.client_ids,
)
def create_keras_model(seed=None):
    """Build the (uncompiled) Keras binary classifier.

    Architecture: 11 inputs -> Dense(15, relu) -> Dense(15, relu)
    -> Dense(1, sigmoid).

    Args:
        seed: Base seed for the kernel initializers. When omitted, the
            module-level `seed_value` is used. Each layer receives its own
            seed derived from the base so the layers are not initialised
            from the same random stream.

    Returns:
        A freshly constructed `Sequential` model.
    """
    if seed is None:
        # Looked up at call time so the function can be defined before the
        # module-level seed is assigned.
        seed = seed_value

    def seeded_glorot(layer_index):
        # Glorot uniform is the Dense default; only the seed is made
        # explicit so every call produces identical initial weights.
        return tf.keras.initializers.GlorotUniform(seed=seed + layer_index)

    model = Sequential()
    model.add(Dense(15, input_dim=11, activation="relu",
                    kernel_initializer=seeded_glorot(0)))
    model.add(Dense(15, activation="relu",
                    kernel_initializer=seeded_glorot(1)))
    model.add(Dense(1, activation="sigmoid",
                    kernel_initializer=seeded_glorot(2)))
    return model
def model_fl():
    """Model factory handed to TFF.

    Creates a new Keras model on every call and wraps it as a TFF learning
    model with binary cross-entropy loss and binary accuracy as metric. The
    input spec is taken from the first entry of `fl_train_data`.

    Returns:
        The object returned by `tff.learning.from_keras_model`.
    """
    return tff.learning.from_keras_model(
        keras_model=create_keras_model(),
        input_spec=fl_train_data[0].element_spec,
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=[tf.keras.metrics.BinaryAccuracy()],
    )
# Federated Averaging process: clients train with SGD (learning rate 0.01),
# the server applies the aggregated update with SGD (learning rate 1.0).
# The optimizers are supplied as zero-argument factories.
fl_process = tff.learning.build_federated_averaging_process(
    model_fn=model_fl,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
)
# Initialize federated averaging: create the initial server state.
state = fl_process.initialize()

# Federated rounds: every round feeds the current state together with the
# datasets of all clients into the process and reports the returned metrics.
# The counter is named `round_num` so the builtin `round()` stays usable
# after the loop.
for round_num in range(1, num_rounds + 1):
    state, metrics = fl_process.next(state, fl_train_data)
    print("Runde {:2d}, metrics={}".format(round_num, metrics))