我有一个自定义损失函数,用于实现 fedprox 算法
该函数包装了提供给模型的实际损失函数。
def penalty_loss_func(local_model, global_model, mu, loss_func):
def my_loss_func(y_true, y_pred):
model_difference = nest.map_structure(lambda a, b: a - b, local_model.weights, global_model.weights)
squared_norm = square(linalg.global_norm(model_difference))
return loss_func(y_true, y_pred) + multiply(multiply(mu,0.5),squared_norm)
return my_loss_func
我在同一台机器上运行连续的训练(每个会话代表一个独立的客户端),但每次运行训练都会显着变慢。
我尝试了默认损失,除非我使用自定义损失函数,否则我没有这个问题。
什么可能导致这种放缓?