I designed a function that computes the loss and the gradients with respect to model.trainable_variables using TensorFlow's GradientTape. I use this function for split learning, meaning the model is partitioned and trained on the client up to a specific layer. The client model's output, the labels, and the client's trainable variables are sent to the server, which trains the second half of the model. On the server, this function is supposed to compute both the server-side gradients and the client-side gradients that must be sent back to the client to update the client model:
def calc_gradients(self, msg):
    with tf.GradientTape(persistent=True) as tape:
        # unpack the data received from the client
        output_client, labels, trainable_variables_client = msg["client_out"], msg["label"], msg["trainable_variables"]
        # server-side forward pass on the received activations
        output_server = self.call(output_client)
        l = self.loss(output_server, labels)
    gradient_server = tape.gradient(l, self.model.trainable_variables)
    print(l)
    print(trainable_variables_client)
    gradient_client = tape.gradient(l, trainable_variables_client)
    print("client gradient:")
    print(gradient_client)
The server-side gradients are computed correctly. The loss is computed correctly as well, and the client's trainable_variables arrive at the server intact, but gradient_client = tape.gradient(l, trainable_variables_client) only returns:
client gradient:
[None, None, None, None]
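As far as I can tell, tf.GradientTape only differentiates through operations recorded inside its own context, and it returns None for sources that have no recorded path to the target. Because the client's forward pass runs in a different process, output_client arrives at the server as a plain tensor with no connection to trainable_variables_client, so the server tape has nothing to differentiate. A minimal sketch of the same failure (all names here are hypothetical):

import tensorflow as tf

# Hypothetical stand-in for one client weight.
v = tf.Variable([1.0, 2.0])

with tf.GradientTape() as client_tape:
    client_out = v * 2.0          # recorded on the CLIENT tape only

# Serialising to the server strips the graph; the server receives a
# plain tensor. tf.constant(...) simulates that here.
received = tf.constant(client_out.numpy())

with tf.GradientTape() as server_tape:
    loss = tf.reduce_sum(received * 3.0)

# loss was never computed from v on server_tape, so this is None,
# matching the [None, None, None, None] above.
print(server_tape.gradient(loss, v))   # -> None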
msg is the dictionary containing the data sent from the client to the server:
def start_train(self, batchround):
    if batchround in range(self.num_samples // self.batch_size):
        with tf.GradientTape(persistent=True) as tape:
            # client-side forward pass up to the cut layer
            output_client, labels = self.send(batchround)
            client_trainable_variables = self.model.trainable_variables
            msg = {
                'client_out': output_client,
                'label': labels,
                'trainable_variables': client_trainable_variables,
                'batchround': batchround
            }
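From what I have read about split learning, the usual pattern is not to send the client variables at all: the server backpropagates only as far as the received activations and returns d loss / d output_client, and the client then chains that into its own tape through the output_gradients argument of tape.gradient. A self-contained sketch under that assumption (the two toy models and all names are placeholders, not my actual setup):

import tensorflow as tf

# Hypothetical two-part model standing in for the client/server split.
client_model = tf.keras.Sequential([tf.keras.layers.Dense(4, activation="relu")])
server_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
loss_fn = tf.keras.losses.MeanSquaredError()

x = tf.random.normal([8, 3])
y = tf.random.normal([8, 1])

# --- client: forward pass, keep the tape for later ---
with tf.GradientTape() as client_tape:
    client_out = client_model(x)

# --- server: finish the forward pass, backprop to its own weights
#     and to the received activations ---
with tf.GradientTape(persistent=True) as server_tape:
    server_tape.watch(client_out)            # activations are an input here
    server_out = server_model(client_out)
    loss = loss_fn(y, server_out)

gradient_server = server_tape.gradient(loss, server_model.trainable_variables)
grad_wrt_client_out = server_tape.gradient(loss, client_out)  # sent back

# --- client: chain rule across the split via output_gradients ---
gradient_client = client_tape.gradient(
    client_out,
    client_model.trainable_variables,
    output_gradients=grad_wrt_client_out,
)
print([g.shape for g in gradient_client])    # real gradients, not None

With this pattern the client tape, which did record the client-side forward pass, supplies the missing connection, so gradient_client is no longer a list of None.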