我在 Keras 序列中使用 Tensorflow 概率层。但是,将模型保存为 json 然后加载它会引发异常。我正在使用custom_objects
能够加载自定义图层。这是重现错误的简约代码。
import tensorflow_probability as tfp
tfk = tf.keras
tfkl = tf.keras.layers
tfpl = tfp.layers
original_dim = 20
latent_dim = 2
model = tfk.Sequential([
tfkl.InputLayer(input_shape=original_dim),
tfkl.Dense(10, activation=tf.nn.leaky_relu),
tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(latent_dim), activation=None),
tfpl.MultivariateNormalTriL(latent_dim)
])
model_json = model.to_json()
with open("model.json", "w") as json_file:
json_file.write(model_json)
loaded_model = tfk.models.model_from_json(
open('model.json').read(),
custom_objects={
'leaky_relu': tf.nn.leaky_relu,
'MultivariateNormalTriL': tfpl.MultivariateNormalTriL
}
)
我得到以下异常:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-26-bbbeffd9e4be> in <module>
3 custom_objects={
4 'leaky_relu': tf.nn.leaky_relu,
----> 5 'MultivariateNormalTriL': tfpl.MultivariateNormalTriL
6 }
7 )
//anaconda3/envs/dl-env/lib/python3.7/site-packages/tensorflow/python/keras/saving/model_config.py in model_from_json(json_string, custom_objects)
94 config = json.loads(json_string)
95 from tensorflow.python.keras.layers import deserialize # pylint: disable=g-import-not-at-top
---> 96 return deserialize(config, custom_objects=custom_objects)
//anaconda3/envs/dl-env/lib/python3.7/site-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
87 module_objects=globs,
88 custom_objects=custom_objects,
---> 89 printable_module_name='layer')
//anaconda3/envs/dl-env/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
190 custom_objects=dict(
191 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 192 list(custom_objects.items())))
193 with CustomObjectScope(custom_objects):
194 return cls.from_config(cls_config)
//anaconda3/envs/dl-env/lib/python3.7/site-packages/tensorflow/python/keras/engine/sequential.py in from_config(cls, config, custom_objects)
350 for layer_config in layer_configs:
351 layer = layer_module.deserialize(layer_config,
--> 352 custom_objects=custom_objects)
353 model.add(layer)
354 if not model.inputs and build_input_shape:
//anaconda3/envs/dl-env/lib/python3.7/site-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
87 module_objects=globs,
88 custom_objects=custom_objects,
---> 89 printable_module_name='layer')
//anaconda3/envs/dl-env/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
190 custom_objects=dict(
191 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 192 list(custom_objects.items())))
193 with CustomObjectScope(custom_objects):
194 return cls.from_config(cls_config)
//anaconda3/envs/dl-env/lib/python3.7/site-packages/tensorflow_probability/python/layers/distribution_layer.py in from_config(cls, config, custom_objects)
875 config['arguments'][key] = np.array(arg_dict['value'])
876
--> 877 return cls(**config)
878
879 @classmethod
TypeError: __init__() missing 1 required positional argument: 'event_size'