
I'm trying to customize the model in TensorFlow Federated's "Image Classification" tutorial (it originally uses a sequential model). I used the Keras ResNet50, but when training starts it always fails with an "Incompatible shapes" error.

Here is my code:

import tensorflow as tf
import tensorflow_federated as tff

NUM_CLIENTS = 4
NUM_EPOCHS = 10
BATCH_SIZE = 2
SHUFFLE_BUFFER = 5

def create_compiled_keras_model():
  model = tf.keras.applications.resnet.ResNet50(
      include_top=False, weights='imagenet',
      input_tensor=tf.keras.layers.Input(shape=(100, 300, 3)),
      pooling=None)

  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.02),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return model


def model_fn():
  keras_model = create_compiled_keras_model()
  return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

iterative_process = tff.learning.build_federated_averaging_process(model_fn)

Error message: (posted as a screenshot; it shows the "Incompatible shapes" error)

I think the shapes are incompatible because the epoch and client information somehow got lost. I would really appreciate it if someone could give me a hint.

Update:

The AssertionError happens in tff.learning.build_federated_averaging_process:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-164-dac26193d9d8> in <module>()
----> 1 iterative_process = tff.learning.build_federated_averaging_process(model_fn)
      2 
      3 # iterative_process = build_federated_averaging_process(model_fn)

13 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/federated_averaging.py in build_federated_averaging_process(model_fn, server_optimizer_fn, client_weight_fn, stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
    165   return optimizer_utils.build_model_delta_optimizer_process(
    166       model_fn, client_fed_avg, server_optimizer_fn,
--> 167       stateful_delta_aggregate_fn, stateful_model_broadcast_fn)

/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/framework/optimizer_utils.py in build_model_delta_optimizer_process(model_fn, model_to_client_delta_fn, server_optimizer_fn, stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
    349   # still need this.
    350   with tf.Graph().as_default():
--> 351     dummy_model_for_metadata = model_utils.enhance(model_fn())
    352 
    353   # ===========================================================================

<ipython-input-159-b2763ace8e5b> in model_fn()
      1 def model_fn():
      2   keras_model = model
----> 3   return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/keras_utils.py in from_compiled_keras_model(keras_model, dummy_batch)
    211   # Model.test_on_batch() once before asking for metrics.
    212   if isinstance(dummy_tensors, collections.Mapping):
--> 213     keras_model.test_on_batch(**dummy_tensors)
    214   else:
    215     keras_model.test_on_batch(*dummy_tensors)

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py in test_on_batch(self, x, y, sample_weight, reset_metrics)
   1007         sample_weight=sample_weight,
   1008         reset_metrics=reset_metrics,
-> 1009         standalone=True)
   1010     outputs = (
   1011         outputs['total_loss'] + outputs['output_losses'] + outputs['metrics'])

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in test_on_batch(model, x, y, sample_weight, reset_metrics, standalone)
    503       y,
    504       sample_weights=sample_weights,
--> 505       output_loss_metrics=model._output_loss_metrics)
    506 
    507   if reset_metrics:

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
    568         xla_context.Exit()
    569     else:
--> 570       result = self._call(*args, **kwds)
    571 
    572     if tracing_count == self._get_tracing_count():

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
    606       # In this case we have not created variables on the first call. So we can
    607       # run the first trace but we should fail if variables are created.
--> 608       results = self._stateful_fn(*args, **kwds)
    609       if self._created_variables:
    610         raise ValueError("Creating variables on a non-first call to a function"

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
   2407     """Calls a graph function specialized to the inputs."""
   2408     with self._lock:
-> 2409       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
   2410     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2411 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2765 
   2766       self._function_cache.missed.add(call_context_key)
-> 2767       graph_function = self._create_graph_function(args, kwargs)
   2768       self._function_cache.primary[cache_key] = graph_function
   2769       return graph_function, args, kwargs

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2655             arg_names=arg_names,
   2656             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2657             capture_by_value=self._capture_by_value),
   2658         self._function_attributes,
   2659         # Tell the ConcreteFunction to clean up its graph once it goes out of

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    979         _, original_func = tf_decorator.unwrap(python_func)
    980 
--> 981       func_outputs = python_func(*func_args, **func_kwargs)
    982 
    983       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    437         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    438         # the function a weak reference to itself to avoid a reference cycle.
--> 439         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    440     weak_wrapped_fn = weakref.ref(wrapped_fn)
    441 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

AssertionError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_eager.py:345 test_on_batch  *
        with backend.eager_learning_phase_scope(0):
    /usr/lib/python3.6/contextlib.py:81 __enter__
        return next(self.gen)
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/backend.py:425 eager_learning_phase_scope
        assert ops.executing_eagerly_outside_functions()

    AssertionError: 


2 Answers


Ah, I believe this problem comes from the sample_batch. TFF passes the sample_batch to Keras and runs a forward pass on that sample batch in order to initialize various attributes of the Keras model. The sample_batch should be either a sample of the literal data you are going to feed to the model on the server side, or a batch of fake data that matches the shape and type of the data you will be passing in.

An example of the former can be found here (using a tf.data.Dataset there), and several examples of the latter can be found in the test code, for example here.

From the model definition I can see, the x element of your sample_batch should probably be an ndarray of shape [2, 100, 300, 3] (where 2 represents the batch size, though technically this can be any nonzero dimension), and the y element should likewise match the structure of the y in the data you are using.
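
For illustration, a minimal fake sample_batch along those lines might look like the sketch below. The 'x'/'y' key names, the dtypes, and the [2, 1] label shape are assumptions here; they have to mirror the structure of the batches your federated dataset actually yields.

import collections
import numpy as np

# Fake batch matching the [2, 100, 300, 3] input described above; key names,
# dtypes and label shape are assumptions and must match what your
# tf.data.Dataset yields.
sample_batch = collections.OrderedDict(
    x=np.zeros([2, 100, 300, 3], dtype=np.float32),  # two fake 100x300 RGB images
    y=np.zeros([2, 1], dtype=np.int64))              # sparse integer labels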

I hope this helps; please reply if you have any questions!

One note that may help in thinking about TFF: TFF builds a syntax tree representing the computation you construct via build_federated_averaging_process. This error actually occurs during the construction of that object. TFF has to trace the computation you pass it in order to know what structure to generate, and that tracing is what raises the issue here. The actual training happens only when you call next on the returned IterativeProcess.
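
As a minimal sketch of that last point (assuming federated_train_data is a list of client tf.data.Datasets prepared as in the tutorial):

state = iterative_process.initialize()
# Only this call actually executes the traced training round.
state, metrics = iterative_process.next(state, federated_train_data)
print('round 1, metrics={}'.format(metrics))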

Answered 2020-01-08T16:52:22.523

I have the same problem. If I execute these lines:

state, metrics = iterative_process.next(state, federated_train_data)
print('round 1, metrics={}'.format(metrics))

I get this error:

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: Default MaxPoolingOp only supports NHWC on device type CPU
      [[{{node StatefulPartitionedCall/StatefulPartitionedCall/sequential/vgg16/block1_pool/MaxPool}}]]
      [[subcomputation/StatefulPartitionedCall_1/ReduceDataset]]
      [[subcomputation/StatefulPartitionedCall_1/ReduceDataset/_140]]
  (1) Invalid argument: Default MaxPoolingOp only supports NHWC on device type CPU
      [[{{node StatefulPartitionedCall/StatefulPartitionedCall/sequential/vgg16/block1_pool/MaxPool}}]]
      [[subcomputation/StatefulPartitionedCall_1/ReduceDataset]]
0 successful operations. 0 derived errors ignored.

Given that I am using VGG16, do you have any idea about this kind of error?
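
As a minimal sanity check, assuming this NHWC error means the tensors reach the pooling layer in channels-first (NCHW) layout on CPU, one could confirm that Keras is configured for channels-last inputs (the 224x224 input shape below is an assumption):

import tensorflow as tf

# Assumption: the error points to a channels-first layout somewhere; Keras
# should report 'channels_last' and VGG16 should expect (height, width, 3).
print(tf.keras.backend.image_data_format())  # expected: 'channels_last'
vgg = tf.keras.applications.VGG16(
    include_top=False, weights='imagenet',
    input_tensor=tf.keras.layers.Input(shape=(224, 224, 3)))
print(vgg.input_shape)  # (None, 224, 224, 3)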

Answered 2020-01-30T09:12:34.983