这是我的模型,我在 TensorFlow 中实现过一次。
def create_compiled_keras_model():
inputs = Input(shape=(7, 20, 1))
l0_c = Conv2D(32, kernel_size=(7, 7), padding='valid', activation='relu')(inputs)
l1_c = Conv2D(32, kernel_size=(1, 5), padding='same', activation='relu')(l0_c)
l1_p = AveragePooling2D(pool_size=(1, 2), strides=2, padding='same')(l1_c)
l2_c = Conv2D(64, kernel_size=(1, 4), padding='same', activation='relu')(l1_p)
l2_p = AveragePooling2D(pool_size=(1, 2), strides=2, padding='same')
l3_c = Conv2D(2, kernel_size=(1, 1), padding='valid', activation='sigmoid')(l2_p)
predictions = Flatten()(l3_c)
predictions = tf.cast(predictions, dtype='float32')
model = Model(inputs=inputs, outputs=predictions)
opt = Adam(lr=0.0005)
print(model.summary())
def loss_fn(y_true, y_pred):
return tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_pred, y_true))
model.compile(optimizer=opt,
loss=loss_fn,
metrics=['accuracy'])
return model
我在 TensorFlow Federated 中收到此错误。
Traceback (most recent call last):
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 270, in report
keras_metric = metric_type.from_config(metric_config)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py", line 594, in from_config
return cls(**config)
TypeError: __init__() missing 1 required positional argument: 'fn'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/amir/Documents/CODE/Python/FL/fl_dataset_khudemon/fl.py", line 203, in <module>
quantization_part = FedAvgQ.build_federated_averaging_process(model_fn)
File "/Users/amir/Documents/CODE/Python/FL/fl_dataset_khudemon/new_fedavg_keras.py", line 195, in build_federated_averaging_process
stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/framework/optimizer_utils.py", line 351, in build_model_delta_optimizer_process
dummy_model_for_metadata = model_utils.enhance(model_fn())
File "/Users/amir/Documents/CODE/Python/FL/fl_dataset_khudemon/fl.py", line 196, in model_fn
return tff.learning.from_compiled_keras_model(keras_model, sample_batch)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 216, in from_compiled_keras_model
return model_utils.enhance(_TrainableKerasModel(keras_model, dummy_tensors))
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 491, in __init__
inner_model.loss_weights, inner_model.metrics)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 381, in __init__
federated_output, federated_local_outputs_type)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/api/computations.py", line 223, in federated_computation
return computation_wrapper_instances.federated_computation_wrapper(*args)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 410, in __call__
self._wrapper_fn)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 103, in _wrap
concrete_fn = wrapper_fn(fn, parameter_type, unpack=None)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper_instances.py", line 78, in _federated_computation_wrapper_fn
suggested_name=name))
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/federated_computation_utils.py", line 76, in zero_or_one_arg_fn_to_building_block
context_stack))
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 652, in <lambda>
return lambda arg: _call(fn, parameter_type, arg)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 645, in _call
return fn(arg)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 377, in federated_output
type(metric), metric.get_config(), variables)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 260, in federated_aggregate_keras_metric
@tff.tf_computation(member_type)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 415, in <lambda>
return lambda fn: _wrap(fn, arg_type, self._wrapper_fn)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 103, in _wrap
concrete_fn = wrapper_fn(fn, parameter_type, unpack=None)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper_instances.py", line 44, in _tf_wrapper_fn
target_fn, parameter_type, ctx_stack)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/tensorflow_serialization.py", line 278, in serialize_py_fn_as_tf_computation
result = target(*args)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 652, in <lambda>
return lambda arg: _call(fn, parameter_type, arg)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 645, in _call
return fn(arg)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 278, in report
t=metric_type, c=metric_config, e=e))
TypeError: Caught exception trying to call `<class 'tensorflow.python.keras.metrics.MeanMetricWrapper'>.from_config()` with config {'name': 'accuracy', 'dtype': 'float32'}. Confirm that <class 'tensorflow.python.keras.metrics.MeanMetricWrapper'>.__init__() has an argument for each member of the config.
Exception: __init__() missing 1 required positional argument: 'fn'
我的数据集的标签是一种两个标签[0. 1.]
,我用于binary_crossentropy
损失函数。但是准确度会恢复错误。我确信它与多个标签有关。当我删除精度时,损失计算没有任何问题。任何帮助将不胜感激。