0

我有一个自定义损失函数,用于实现 fedprox 算法

该函数包装了提供给模型的实际损失函数。

def penalty_loss_func(local_model, global_model, mu, loss_func):
    def my_loss_func(y_true, y_pred):
        model_difference = nest.map_structure(lambda a, b: a - b, local_model.weights, global_model.weights)
        squared_norm = square(linalg.global_norm(model_difference))
        return loss_func(y_true, y_pred) + multiply(multiply(mu,0.5),squared_norm)
    return my_loss_func

我在同一台机器上运行连续的训练(每个会话代表一个独立的客户端),但每次运行训练都会显着变慢。

我尝试了默认损失,除非我使用自定义损失函数,否则我没有这个问题。

什么可能导致这种放缓?

4

0 回答 0