I implemented a ResNet34 model in the federated image classification tutorial. After 10 rounds the training accuracy gets above 90%; however, the evaluation accuracy computed with the last round's state.model is always around 50%.
evaluation = tff.learning.build_federated_evaluation(model_fn)
federated_test_data = make_federated_data(emnist_test, sample_clients)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
I am confused about what could be wrong in the evaluation part. Also, I printed the non-trainable variables of the server model (the moving mean and variance in the BatchNorm layers), and they are still 0 and 1, i.e. they are never updated/averaged over the rounds. Is that how they should be, or could that be the problem? Thanks a lot!
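For context, this is roughly how I checked those variables; a minimal sketch, assuming the tff.learning.assign_weights_to_keras_model helper from the older TFF releases that still provide from_compiled_keras_model:

# Sketch: copy the server weights back into a Keras model and inspect
# the BatchNorm moving statistics (non-trainable variables).
keras_model = create_compiled_keras_model()
tff.learning.assign_weights_to_keras_model(keras_model, state.model)
for v in keras_model.non_trainable_variables:
  print(v.name, v.numpy().mean())  # moving_mean / moving_variance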
Update:
Code for preparing the training data, and the printed results:
len(emnist_train.client_ids)
4
emnist_train.element_type_structure
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int64, name=None)),('pixels',TensorSpec(shape=(256, 256, 3), dtype=tf.float32, name=None))])
NUM_CLIENTS = 4
NUM_EPOCHS = 1
BATCH_SIZE = 30
SHUFFLE_BUFFER = 500
def preprocess(dataset):

  def element_fn(element):
    return collections.OrderedDict([
        ('x', element['pixels']),
        ('y', tf.reshape(element['label'], [1])),
    ])

  return dataset.repeat(NUM_EPOCHS).map(element_fn).shuffle(
      SHUFFLE_BUFFER).batch(BATCH_SIZE)
sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]
federated_train_data = make_federated_data(emnist_train, sample_clients)
preprocessed_example_dataset = preprocess(example_dataset)
sample_batch = tf.nest.map_structure(
    lambda x: x.numpy(), iter(preprocessed_example_dataset).next())
def make_federated_data(client_data, client_ids):
  return [preprocess(client_data.create_tf_dataset_for_client(x))
          for x in client_ids]
len(federated_train_data), federated_train_data[0]
(4,
 <BatchDataset shapes: OrderedDict([(x, (None, 256, 256, 3)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int64)])>)
Training and evaluation code:
def create_compiled_keras_model():
  base_model = tf.keras.applications.resnet.ResNet50(
      include_top=False, weights='imagenet', input_shape=(256, 256, 3))
  global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
  prediction_layer = tf.keras.layers.Dense(2, activation='softmax')
  model = tf.keras.Sequential([
      base_model,
      global_average_layer,
      prediction_layer
  ])
  model.compile(
      optimizer=tf.keras.optimizers.SGD(lr=0.001, momentum=0.9),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return model
def model_fn():
  keras_model = create_compiled_keras_model()
  return tff.learning.from_compiled_keras_model(keras_model, sample_batch)
iterative_process = tff.learning.build_federated_averaging_process(model_fn)
state = iterative_process.initialize()
for round_num in range(2, 12):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
evaluation = tff.learning.build_federated_evaluation(model_fn)
federated_test_data = make_federated_data(emnist_test, sample_clients)
len(federated_test_data), federated_test_data[0]
(4,
<BatchDataset shapes: OrderedDict([(x, (None, 256, 256, 3)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int64)])>)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
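The per-round evaluation numbers shown below were produced by also calling the evaluation computation inside the training loop; a minimal sketch of that loop, reusing the iterative_process, evaluation, and federated datasets defined above:

# Sketch: run federated evaluation on the current server model after
# every training round (10 rounds, as mentioned at the top).
for round_num in range(1, 11):
  state, train_metrics = iterative_process.next(state, federated_train_data)
  test_metrics = evaluation(state.model, federated_test_data)
  print('round {:2d}, metrics={}'.format(round_num, train_metrics))
  print(str(test_metrics))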
Training and evaluation results after each round (the first line of each pair is the training metrics, the second line the evaluation metrics):
round 1, metrics=<sparse_categorical_accuracy=0.5089045763015747,loss=0.7813001871109009,keras_training_time_client_sum_sec=0.008826255798339844>
<sparse_categorical_accuracy=0.49949443340301514,loss=8.0671968460083,keras_training_time_client_sum_sec=0.0>
round 2, metrics=<sparse_categorical_accuracy=0.519825279712677,loss=0.7640910148620605,keras_training_time_client_sum_sec=0.011750459671020508>
<sparse_categorical_accuracy=0.49949443340301514,loss=8.0671968460083,keras_training_time_client_sum_sec=0.0>
round 3, metrics=<sparse_categorical_accuracy=0.5099126100540161,loss=0.7513422966003418,keras_training_time_client_sum_sec=0.0039823055267333984>
<sparse_categorical_accuracy=0.49949443340301514,loss=8.0671968460083,keras_training_time_client_sum_sec=0.0>
round 4, metrics=<sparse_categorical_accuracy=0.5278897881507874,loss=0.7905193567276001,keras_training_time_client_sum_sec=0.0010638236999511719>
<sparse_categorical_accuracy=0.49949443340301514,loss=8.0671968460083,keras_training_time_client_sum_sec=0.0>
round 5, metrics=<sparse_categorical_accuracy=0.5199933052062988,loss=0.7782396674156189,keras_training_time_client_sum_sec=0.012729644775390625>
<sparse_categorical_accuracy=0.49949443340301514,loss=8.0671968460083,keras_training_time_client_sum_sec=0.0>