I'm trying transfer learning on an image classification problem in Google Colab. When I run this code:
import tensorflow as tf
import tensorflow_hub as hub

# Setup input shape to the model
INPUT_SHAPE = [None, 244, 244, 3]  # batch, height, width, colour channels
# Setup output shape of the model
OUTPUT_SHAPE = 120
# Setup model URL from TensorFlow Hub
MODEL_URL = "https://tfhub.dev/google/imagenet/mobilenet_v2_130_224/classification/4"

# Create a function which builds a Keras model
def create_model(input_shape=INPUT_SHAPE, output_shape=OUTPUT_SHAPE, model_url=MODEL_URL):
    print("Building model with:", model_url)
    # Setup the model layers
    model = tf.keras.Sequential([
        hub.KerasLayer(model_url),  # Layer 1 (input layer)
        tf.keras.layers.Dense(units=output_shape,
                              activation="softmax")  # Layer 2 (output layer)
    ])
    # Compile the model
    model.compile(
        loss=tf.keras.losses.CategoricalCrossentropy(),
        optimizer=tf.keras.optimizers.Adam(),
        metrics=["accuracy"]
    )
    # Build the model
    model.build(input_shape)
    return model

model = create_model()
model.summary()
I get this error:
Building model with: https://tfhub.dev/google/imagenet/mobilenet_v2_130_224/classification/4
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-43-0fd4f47c95c0> in <module>()
----> 1 model = create_model()
      2 model.summary()

5 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    263         except Exception as e:  # pylint:disable=broad-except
    264           if hasattr(e, 'ag_error_metadata'):
--> 265             raise e.ag_error_metadata.to_exception(e)
    266           else:
    267             raise

ValueError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow_hub/keras_layer.py:229 call *
        result = smart_cond.smart_cond(training,
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/load.py:486 _call_attribute **
        return instance.__call__(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:580 __call__
        result = self._call(*args, **kwds)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:627 _call
        self._initialize(args, kwds, add_initializers_to=initializers)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:506 _initialize
        *args, **kwds))
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2446 _get_concrete_function_internal_garbage_collected
        graph_function, _, _ = self._maybe_define_function(args, kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2777 _maybe_define_function
        graph_function = self._create_graph_function(args, kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2667 _create_graph_function
        capture_by_value=self._capture_by_value),
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:981 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:441 wrapped_fn
        return weak_wrapped_fn().__wrapped__(*args, **kwds)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/function_deserialization.py:261 restored_function_body
        "\n\n".join(signature_descriptions)))

    ValueError: Could not find matching function to call loaded from the SavedModel. Got:
      Positional arguments (4 total):
        * Tensor("inputs:0", shape=(None, 244, 244, 3), dtype=float32)
        * False
        * False
        * 0.99
      Keyword arguments: {}

    Expected these arguments to match one of the following 4 option(s):

    Option 1:
      Positional arguments (4 total):
        * TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='inputs')
        * False
        * True
        * TensorSpec(shape=(), dtype=tf.float32, name='batch_norm_momentum')
      Keyword arguments: {}

    Option 2:
      Positional arguments (4 total):
        * TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='inputs')
        * True
        * False
        * TensorSpec(shape=(), dtype=tf.float32, name='batch_norm_momentum')
      Keyword arguments: {}

    Option 3:
      Positional arguments (4 total):
        * TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='inputs')
        * True
        * True
        * TensorSpec(shape=(), dtype=tf.float32, name='batch_norm_momentum')
      Keyword arguments: {}

    Option 4:
      Positional arguments (4 total):
        * TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='inputs')
        * False
        * False
        * TensorSpec(shape=(), dtype=tf.float32, name='batch_norm_momentum')
      Keyword arguments: {}
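The four options in the message differ from one another only in the two boolean flags; every one of them expects inputs of shape (None, 224, 224, 3), while the layer was called with (None, 244, 244, 3). A minimal, TensorFlow-free sketch that makes this comparison explicit, using the shapes copied from the traceback. `mismatched_dims` and `resolution_from_handle` are hypothetical helpers, and the latter assumes the TF Hub model name ends in `_<resolution>`, as `mobilenet_v2_130_224` does:

```python
from typing import List, Optional, Sequence, Tuple


def mismatched_dims(got: Sequence[Optional[int]],
                    expected: Sequence[Optional[int]]) -> List[Tuple[int, int, int]]:
    """Return (axis, got, expected) for every axis whose sizes disagree.

    None means "any size" (e.g. the batch axis) and matches anything.
    """
    if len(got) != len(expected):
        raise ValueError(f"rank mismatch: {len(got)} vs {len(expected)}")
    return [
        (axis, g, e)
        for axis, (g, e) in enumerate(zip(got, expected))
        if g is not None and e is not None and g != e
    ]


def resolution_from_handle(url: str) -> int:
    """Read the input resolution from a TF Hub MobileNet handle.

    Assumption: the URL looks like .../<model_name>/classification/<version>
    and <model_name> ends in _<resolution>, e.g. mobilenet_v2_130_224.
    """
    model_name = url.rstrip("/").split("/")[-3]
    return int(model_name.rsplit("_", 1)[1])


if __name__ == "__main__":
    MODEL_URL = "https://tfhub.dev/google/imagenet/mobilenet_v2_130_224/classification/4"
    INPUT_SHAPE = [None, 244, 244, 3]  # shape used in the question
    EXPECTED = [None, 224, 224, 3]     # TensorSpec shape in all four options

    print(mismatched_dims(INPUT_SHAPE, EXPECTED))  # [(1, 244, 224), (2, 244, 224)]
    print(resolution_from_handle(MODEL_URL))       # 224
```

If this mismatch is the cause, setting INPUT_SHAPE to [None, 224, 224, 3] (and resizing the images to 224×224 to match) should let model.build succeed; that is an inference from the signatures in the message, not something verified here.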
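I tried installing tf-nightly and older versions of TensorFlow to see whether it would run, but that didn't work. I also tried an older version of tensorflow_hub, which only led to more errors. I factory-reset the notebook and tried again, but got the same error. The error does not appear if I comment out model.build(INPUT_SHAPE). Beyond that, I'm not sure how to fix the problem.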