
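I am classifying handwritten math symbols with federated learning, following the tutorial on federated learning for image classification. The error is raised at the end of the server broadcast step. Full error: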

>InvalidArgumentError: Graph execution error:

>2 root error(s) found.
>  (0) INVALID_ARGUMENT:  2 root error(s) found.
  (0) INVALID_ARGUMENT:  required broadcastable shapes
     [[{{node Equal}}]]
     [[Identity/_32]]
  (1) INVALID_ARGUMENT:  required broadcastable shapes
     [[{{node Equal}}]]
0 successful operations.
0 derived errors ignored.
     [[StatefulPartitionedCall/StatefulPartitionedCall/ReduceDataset]]
     [[Func/StatefulPartitionedCall/StatefulPartitionedCall/cond/else/_90/input/_118/_52]]
  (1) INVALID_ARGUMENT:  2 root error(s) found.
  (0) INVALID_ARGUMENT:  required broadcastable shapes
     [[{{node Equal}}]]
     [[Identity/_32]]
  (1) INVALID_ARGUMENT:  required broadcastable shapes
     [[{{node Equal}}]]
0 successful operations.
0 derived errors ignored.
     [[StatefulPartitionedCall/StatefulPartitionedCall/ReduceDataset]]
0 successful operations.
0 derived errors ignored. [Op:__inference_pruned_26214]
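The failing node is `Equal`. In this setup it most likely comes from the `SparseCategoricalAccuracy` metric, which compares predicted class indices with the labels; "required broadcastable shapes" means the two tensors' shapes cannot be broadcast against each other. A minimal pure-Python sketch of the broadcasting rule (the same rule NumPy and TensorFlow use) makes it easier to see which label shapes would trigger the error. `broadcast_shape` is a hypothetical helper for illustration, not part of TensorFlow.

```python
def broadcast_shape(a, b):
    """Return the broadcast shape of shapes a and b, or raise ValueError.

    Dimensions are compared from the right; two dimensions are compatible
    if they are equal or if one of them is 1.
    """
    # Pad the shorter shape on the left with 1s so both have the same rank.
    rank = max(len(a), len(b))
    a = (1,) * (rank - len(a)) + tuple(a)
    b = (1,) * (rank - len(b)) + tuple(b)
    result = []
    for da, db in zip(a, b):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"shapes not broadcastable: {da} vs {db}")
    return tuple(result)


print(broadcast_shape((32,), (32,)))     # (32,)
try:
    # e.g. argmax predictions of shape (batch,) vs one-hot labels (batch, 82)
    broadcast_shape((32,), (32, 82))
except ValueError as err:
    print(err)
```

Predicted indices of shape `(batch,)` compared against one-hot labels of shape `(batch, 82)`, or against labels from a batch of a different size, both fail this check.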

Keras model:

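import tensorflow as tf
import tensorflow_federated as tff

def create_keras_model():
  # Input: a flattened 45x45 image, i.e. 2025 features per example.
  cnn = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(45*45,)),
      # One output unit per symbol class (82 classes).
      tf.keras.layers.Dense(82, kernel_initializer='zeros'),
      # Softmax outputs probabilities, matching the loss's default from_logits=False.
      tf.keras.layers.Softmax(),
  ])
  return cnn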

model_fn():

def model_fn():
  # TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      # The batch spec must match the model: x of shape (batch, 2025) and
      # y as integer class indices, as the Sparse* loss and metric expect.
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
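`from_keras_model` uses `input_spec` as the type of every batch, so the spec of the first client dataset has to agree with the Keras model above: `x` flattened to 45*45 = 2025 features, and `y` holding integer class indices, since both the loss and the metric are the `Sparse*` variants. Below is a minimal pure-Python sketch of that check. `spec_problems` is a hypothetical helper, and it assumes the dataset yields `OrderedDict(x=..., y=...)` batches as in the tutorial, so the shapes can be read as `tuple(spec['x'].shape)` and `tuple(spec['y'].shape)` from `federated_train_data[0].element_spec`.

```python
NUM_FEATURES = 45 * 45  # matches InputLayer(input_shape=(45*45,))
NUM_CLASSES = 82        # matches Dense(82)


def spec_problems(x_shape, y_shape,
                  num_features=NUM_FEATURES, num_classes=NUM_CLASSES):
    """List likely mismatches between a batch spec and the model above.

    Shapes are tuples of ints, with None for an unknown (batch) dimension,
    as obtained from tuple(tensor_spec.shape).
    """
    problems = []
    if len(x_shape) != 2 or x_shape[-1] != num_features:
        problems.append(
            f"x should have shape (batch, {num_features}), got {x_shape}")
    if len(y_shape) == 2 and y_shape[-1] == num_classes:
        problems.append(
            "y looks one-hot encoded, but the Sparse* loss/metric "
            "expect integer class indices")
    elif not (len(y_shape) == 1 or (len(y_shape) == 2 and y_shape[-1] == 1)):
        problems.append(
            f"y should have shape (batch,) or (batch, 1), got {y_shape}")
    return problems


print(spec_problems((None, 2025), (None, 1)))   # []
print(spec_problems((None, 2025), (None, 82)))  # one-hot warning
```

If the labels do turn out to be one-hot, either switching to `CategoricalCrossentropy` and `CategoricalAccuracy` or converting the labels back to indices (for example with `tf.argmax` in the preprocessing) would make the shapes consistent; which fits better depends on how the dataset was built.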

Model: "sequential_14"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_20 (Dense)            (None, 82)                166132    
                                                                 
 dense_21 (Dense)            (None, 82)                6806      
                                                                 
=================================================================
Total params: 172,938
Trainable params: 172,938
Non-trainable params: 0
__________________________
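The parameter counts in this summary can be reproduced by hand: a Dense layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases. The arithmetic below shows the summary corresponds to two stacked `Dense(82)` layers, whereas `create_keras_model` above builds one `Dense(82)` followed by a parameter-free `Softmax` layer, so the summary appears to have been printed from a different version of the model. `dense_params` is a hypothetical helper for illustration.

```python
def dense_params(n_in, n_out):
    """Trainable parameters of a Dense layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out


first = dense_params(45 * 45, 82)   # 2025 inputs -> 82 units
second = dense_params(82, 82)       # 82 inputs -> 82 units
print(first, second, first + second)  # 166132 6806 172938
```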

I have been training on the federated data:

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
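`build_federated_averaging_process` implements Federated Averaging: each client starts from the broadcast server model and trains locally (here SGD with learning rate 0.02), and the server averages the clients' model updates, weighted by each client's number of examples, before applying the server optimizer. With plain server SGD at `learning_rate=1.0`, applying the averaged update is equivalent to replacing the server weights with the weighted average of the client weights. Below is a minimal pure-Python sketch of that server step, using flat lists of floats as hypothetical stand-ins for the model weights; the real process operates on TF tensors and runs client training inside TFF.

```python
def fedavg_server_update(server_weights, client_weights, client_num_examples,
                         server_lr=1.0):
    """One FedAvg server step on flat weight vectors (lists of floats).

    Each client's delta (client weights minus server weights) is averaged
    with weights proportional to its number of examples; the server then
    takes a plain SGD step along the averaged delta.
    """
    total = sum(client_num_examples)
    avg_delta = [0.0] * len(server_weights)
    for weights, n in zip(client_weights, client_num_examples):
        for i, (w_c, w_s) in enumerate(zip(weights, server_weights)):
            avg_delta[i] += (n / total) * (w_c - w_s)
    # SGD with the negated delta as the gradient: w - lr * (-delta) = w + lr * delta
    return [w + server_lr * d for w, d in zip(server_weights, avg_delta)]


# Two clients holding 1 and 3 examples respectively.
print(fedavg_server_update([0.0, 0.0], [[1.0, 2.0], [3.0, 4.0]], [1, 3]))
```

With `server_lr=1.0` the result is exactly the example-weighted mean of the client weights; a smaller server learning rate moves only part of the way toward it.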

The error occurs at:

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))