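I am classifying handwritten math symbols with federated learning, following the Federated Learning for Image Classification tutorial. The error appears at the end of the server broadcast part. Full error: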
>InvalidArgumentError: Graph execution error:
>2 root error(s) found.
> (0) INVALID_ARGUMENT: 2 root error(s) found.
> (0) INVALID_ARGUMENT: required broadcastable shapes
> [[{{node Equal}}]]
> [[Identity/_32]]
> (1) INVALID_ARGUMENT: required broadcastable shapes
> [[{{node Equal}}]]
> 0 successful operations.
> 0 derived errors ignored.
> [[StatefulPartitionedCall/StatefulPartitionedCall/ReduceDataset]]
> [[Func/StatefulPartitionedCall/StatefulPartitionedCall/cond/else/_90/input/_118/_52]]
> (1) INVALID_ARGUMENT: 2 root error(s) found.
> (0) INVALID_ARGUMENT: required broadcastable shapes
> [[{{node Equal}}]]
> [[Identity/_32]]
> (1) INVALID_ARGUMENT: required broadcastable shapes
> [[{{node Equal}}]]
> 0 successful operations.
> 0 derived errors ignored.
> [[StatefulPartitionedCall/StatefulPartitionedCall/ReduceDataset]]
> 0 successful operations.
> 0 derived errors ignored. [Op:__inference_pruned_26214]
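
For reference, `required broadcastable shapes` at `node Equal` means an element-wise comparison received two tensors whose shapes cannot be broadcast against each other. The sketch below is a plain-Python illustration of that rule, not TensorFlow code. The helper `broadcastable` and the example shapes are hypothetical: the post does not show the dataset's `element_spec`, so the one-hot label shape is only an assumption about one way such a mismatch could arise when predicted class indices are compared with labels.

```python
from itertools import zip_longest


def broadcastable(shape_a, shape_b):
    """Return True if two shapes are compatible under NumPy-style broadcasting."""
    # Compare dimensions from the trailing end; a missing dimension counts as 1.
    for a, b in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if a != b and a != 1 and b != 1:
            return False
    return True


batch = 20                      # assumed batch size, for illustration only
predicted_classes = (batch,)    # argmax over the 82 class scores
sparse_labels = (batch,)        # one integer class id per example
one_hot_labels = (batch, 82)    # hypothetical one-hot encoded labels

print(broadcastable(predicted_classes, sparse_labels))   # True
print(broadcastable(predicted_classes, one_hot_labels))  # False
```

If the labels in `federated_train_data` turned out to have a trailing dimension of 82, that would be consistent with the error, but it would need to be checked against the actual data.
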
Keras model:

    def create_keras_model():
        cnn = tf.keras.models.Sequential([
            tf.keras.layers.InputLayer(input_shape=(45*45,)),
            tf.keras.layers.Dense(82, kernel_initializer='zeros'),
            tf.keras.layers.Softmax(),
        ])
        return cnn

model_fn():

    def model_fn():
        # TFF will call this within different graph contexts.
        keras_model = create_keras_model()
        return tff.learning.from_keras_model(
            keras_model,
            input_spec=federated_train_data[0].element_spec,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

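
Since `input_spec` is taken directly from the dataset, it may help to confirm that the batch shapes agree with what the model and the sparse loss and metric expect: features of shape `(None, 45*45)` and one integer class id per example. Below is a minimal sketch of such a check in plain Python. The helper `check_batch_shapes` is hypothetical, and it takes shapes as plain tuples (with `None` for the batch dimension) rather than reading them from TFF.

```python
def check_batch_shapes(x_shape, y_shape, num_features=45 * 45, num_classes=82):
    """Return a list of readable problems; the list is empty when shapes look fine.

    Shapes are sequences with None for the batch dimension, e.g. (None, 2025).
    """
    problems = []
    # The model's InputLayer expects flattened images of length num_features.
    if len(x_shape) != 2 or x_shape[1] != num_features:
        problems.append(f"x has shape {x_shape}, expected (None, {num_features})")
    # Sparse losses and metrics want one integer class id per example.
    label_dims = tuple(y_shape[1:])
    if label_dims not in ((), (1,)):
        hint = " (looks one-hot)" if label_dims == (num_classes,) else ""
        problems.append(f"y has shape {y_shape}, expected (None,) or (None, 1){hint}")
    return problems


# Shapes that match the model in the post: no problems reported.
print(check_batch_shapes((None, 2025), (None, 1)))
# Hypothetical mismatched shapes: unflattened images and one-hot labels.
print(check_batch_shapes((None, 45, 45), (None, 82)))
```

Assuming `element_spec` is a dict with `x` and `y` entries as in the tutorial, the shapes to pass in could be obtained with `spec['x'].shape.as_list()` and `spec['y'].shape.as_list()`.
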
    Model: "sequential_14"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #
    =================================================================
     dense_20 (Dense)            (None, 82)                166132
     dense_21 (Dense)            (None, 82)                6806
    =================================================================
    Total params: 172,938
    Trainable params: 172,938
    Non-trainable params: 0
    _________________________________________________________________

I am training on the federated data:

    iterative_process = tff.learning.build_federated_averaging_process(
        model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

The error occurs at:

    state, metrics = iterative_process.next(state, federated_train_data)
    print('round 1, metrics={}'.format(metrics))