我按照这个 emnist 教程创建了一个图像分类实验(7 个类),目的是使用 TFF 框架在 3 个数据孤岛上训练分类器。
在训练开始之前,我将模型转换为 tf keras 模型,tff.learning.assign_weights_to_keras_model(model,state.model)
用于评估我的验证集。无论标签如何,该模型都只预测一个类别。这是可以预料的,因为尚未对模型进行训练。但是,我在每轮联合平均后重复此步骤,但问题仍然存在。所有验证图像都被预测为一类。我还在每一轮之后保存 tf keras 模型的权重,并对测试集进行预测——没有变化。
我已采取一些步骤来检查问题的根源:
- 检查 tf keras 模型权重是否在每轮后转换 FL 模型时更新 - 它们正在更新。
- 确保缓冲区大小大于每个客户端的训练数据集大小。
- 将预测与训练数据集中的类分布进行比较。存在类别不平衡,但模型预测的类别不一定是多数类别。此外,它并不总是同一个类。在大多数情况下,它只预测 0 类。
- 将轮数增加到 5 轮,将每轮 epochs 增加到 10。这是计算量非常大的,因为它是一个相当大的模型,每个客户端大约有 1500 张图像进行训练。
- 调查每次训练尝试的 TensorBoard 日志。随着回合的进行,训练损失正在减少。
- 尝试了一个更简单的模型 - 具有 2 个卷积层的基本 CNN。这使我能够大大增加 epochs 和 rounds 的数量。在测试集上评估这个模型时,它预测了 4 个不同的类别,但性能仍然很差。这表明我只需要增加原始模型的轮数和时期数即可增加预测的变化。这很困难,因为这会导致大量的训练时间。
型号详情:
该模型使用 XceptionNet 作为未冻结权重的基础模型。当所有训练图像都汇集到一个全局数据集中时,这在分类任务上表现良好。我们的目标是希望获得与 FL 相当的性能。
base_model = Xception(include_top=False,
weights=weights,
pooling='max',
input_shape=input_shape)
x = GlobalAveragePooling2D()( x )
predictions = Dense( num_classes, activation='softmax' )( x )
model = Model( base_model.input, outputs=predictions )
这是我的训练代码:
def fit(self):
"""Train FL model"""
# self.load_data()
summary_writer = tf.summary.create_file_writer(
self.logs_dir
)
federated_averaging = self._construct_iterative_process()
state = federated_averaging.initialize()
tfkeras_model = self._convert_to_tfkeras_model( state )
print( np.argmax( tfkeras_model.predict( self.val_data ), axis=-1 ) )
val_loss, val_acc = tfkeras_model.evaluate( self.val_data, steps=100 )
with summary_writer.as_default():
for round_num in tqdm( range( 1, self.num_rounds ), ascii=True, desc="FedAvg Rounds" ):
print( "Beginning fed avg round..." )
# Round of federated averaging
state, metrics = federated_averaging.next(
state,
self.training_data
)
print( "Fed avg round complete" )
# Saving logs
for name, value in metrics._asdict().items():
tf.summary.scalar(
name,
value,
step=round_num
)
print( "round {:2d}, metrics={}".format( round_num, metrics ) )
tff.learning.assign_weights_to_keras_model(
tfkeras_model,
state.model
)
# tfkeras_model = self._convert_to_tfkeras_model(
# state
# )
val_metrics = {}
val_metrics["val_loss"], val_metrics["val_acc"] = tfkeras_model.evaluate(
self.val_data,
steps=100
)
for name, metric in val_metrics.items():
tf.summary.scalar(
name=name,
data=metric,
step=round_num
)
self._checkpoint_tfkeras_model(
tfkeras_model,
round_num,
self.checkpoint_dir
)
def _checkpoint_tfkeras_model(self,
model,
round_number,
checkpoint_dir):
# Obtaining model dir path
model_dir = os.path.join(
checkpoint_dir,
f'round_{round_number}',
)
# Creating directory
pathlib.Path(
model_dir
).mkdir(
parents=True
)
model_path = os.path.join(
model_dir,
f'model_file_round{round_number}.h5'
)
# Saving model
model.save(
model_path
)
def _convert_to_tfkeras_model(self, state):
"""Converts global TFF modle of TF keras model
Takes the weights of the global model
and pushes them back into a standard
Keras model
Args:
state: The state of the FL server
containing the model and
optimization state
Returns:
(model); TF Keras model
"""
model = self._load_tf_keras_model()
model.compile(
loss=self.loss,
metrics=self.metrics
)
tff.learning.assign_weights_to_keras_model(
model,
state.model
)
return model
def _load_tf_keras_model(self):
"""Loads tf keras models
Raises:
KeyError: A model name was not defined
correctly
Returns:
(model): TF keras model object
"""
model = create_models(
model_type=self.model_type,
input_shape=[self.img_h, self.img_w, 3],
freeze_base_weights=self.freeze_weights,
num_classes=self.num_classes,
compile_model=False
)
return model
def _define_model(self):
"""Model creation function"""
model = self._load_tf_keras_model()
tff_model = tff.learning.from_keras_model(
model,
dummy_batch=self.sample_batch,
loss=self.loss,
# Using self.metrics throws an error
metrics=[tf.keras.metrics.CategoricalAccuracy()] )
return tff_model
def _construct_iterative_process(self):
"""Constructing federated averaging process"""
iterative_process = tff.learning.build_federated_averaging_process(
self._define_model,
client_optimizer_fn=lambda: tf.keras.optimizers.SGD( learning_rate=0.02 ),
server_optimizer_fn=lambda: tf.keras.optimizers.SGD( learning_rate=1.0 ) )
return iterative_process