I am using Python 3.7.7 and TensorFlow 2.1.0 with the functional API and eager execution.
I am trying to run custom training using the encoder extracted from a pretrained U-Net network:
- I got the U-Net model without compiling it.
- I loaded the pretrained weights into that model.
- I extracted the encoder and the decoder from that model (a simplified sketch of this step is right below).
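The extraction is roughly equivalent to this simplified sketch (unet is just the name I use here for the loaded model; the layer names are the ones in the summary further down, and the extra outputs are the tensors the decoder reuses):

import tensorflow as tf

# Simplified sketch of the extraction step. `unet` is the uncompiled U-Net
# with the pretrained weights already loaded; layer names match the summary below.
skip_names = ["input_1", "conv1_2", "conv2_2", "conv3_2", "conv4_2", "conv5_2"]
encoder = tf.keras.Model(
    inputs=unet.input,
    outputs=[unet.get_layer(name).output for name in skip_names],
    name="encoder",
)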
Then I want to use the encoder, whose summary is:
Model: "encoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 200, 200, 1)]     0
_________________________________________________________________
conv1_1 (Conv2D)             (None, 200, 200, 64)      1664
_________________________________________________________________
conv1_2 (Conv2D)             (None, 200, 200, 64)      102464
_________________________________________________________________
pool1 (MaxPooling2D)         (None, 100, 100, 64)      0
_________________________________________________________________
conv2_1 (Conv2D)             (None, 100, 100, 96)      55392
_________________________________________________________________
conv2_2 (Conv2D)             (None, 100, 100, 96)      83040
_________________________________________________________________
pool2 (MaxPooling2D)         (None, 50, 50, 96)        0
_________________________________________________________________
conv3_1 (Conv2D)             (None, 50, 50, 128)       110720
_________________________________________________________________
conv3_2 (Conv2D)             (None, 50, 50, 128)       147584
_________________________________________________________________
pool3 (MaxPooling2D)         (None, 25, 25, 128)       0
_________________________________________________________________
conv4_1 (Conv2D)             (None, 25, 25, 256)       295168
_________________________________________________________________
conv4_2 (Conv2D)             (None, 25, 25, 256)       1048832
_________________________________________________________________
pool4 (MaxPooling2D)         (None, 12, 12, 256)       0
_________________________________________________________________
conv5_1 (Conv2D)             (None, 12, 12, 512)       1180160
_________________________________________________________________
conv5_2 (Conv2D)             (None, 12, 12, 512)       2359808
=================================================================
Total params: 5,384,832
Trainable params: 5,384,832
Non-trainable params: 0
_________________________________________________________________
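For completeness, this is how I inspect what the extracted encoder declares as its outputs (encoder is just the variable holding the model summarized above):

# Quick check of the encoder's declared outputs.
print(len(encoder.outputs))
for out in encoder.outputs:
    print(out.name, out.shape)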
I use this function for the custom training:
def train_encoder_unet_custom(model, dataset):
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

    for episode in range(num_episodes):
        selected = np.random.permutation(no_of_samples)[:num_shot + num_query]

        # Create our Support Set.
        support_set = np.array(dataset[selected[:num_shot]])
        X_train = support_set[:,0,:]
        y_train = support_set[:,1,:]

        loss_value, grads = grad(model, X_train, y_train)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
The grad function is:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

def loss(model, x, y, training):
    # training=training is needed only if there are layers with different
    # behavior during training versus inference (e.g. Dropout).
    y_ = model(x, training=training)
    return loss_object(y_true=y, y_pred=y_)

def grad(model, inputs, targets):
    with tf.GradientTape() as tape:
        loss_value = loss(model, inputs, targets, training=False)
    return loss_value, tape.gradient(loss_value, model.trainable_variables)
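As a sanity check, the same two helpers run fine on a toy single-output model (the model, shapes and class count below are made up just for this test):

import numpy as np
import tensorflow as tf

# Toy single-output classifier, made up only to exercise loss() and grad().
toy = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
    tf.keras.layers.Dense(3),  # 3-class logits
])
x_toy = np.random.rand(5, 8).astype("float32")
y_toy = np.random.randint(0, 3, size=(5,))

loss_value, grads = grad(toy, x_toy, y_toy)  # no Pack error here
print(float(loss_value), len(grads))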
But when I try to run it, I get this error:
InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [5,12,12,512] != values[1].shape = [5,25,25,256] [Op:Pack] name: packed
Inside the loss function I have checked the value of the y_ variable. It is a list of 6 elements with these shapes:
(5, 12, 12, 512)
(5, 25, 25, 256)
(5, 50, 50, 128)
(5, 100, 100, 96)
(5, 200, 200, 64)
(5, 200, 200, 1)
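So the loss object is apparently being handed that whole list, and the error looks like TensorFlow simply trying to stack tensors of different shapes into one. A minimal reproduction independent of my model (shapes copied from the list above) is:

import tensorflow as tf

# Reproduces the same Pack error: converting a list of tensors with
# different shapes into a single tensor cannot work.
a = tf.zeros([5, 12, 12, 512])
b = tf.zeros([5, 25, 25, 256])
tf.convert_to_tensor([a, b])  # InvalidArgumentError: Shapes of all inputs must match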
Any idea of what is going on?
If you need more details, just ask.
Here is the full call stack:
<ipython-input-133-22827956a9f6> in train_encoder_unet_custom(model, dataset, feat_type, show)
22 y_valid = query_set[:,1,:]
23
---> 24 loss_value, grads = grad(model, X_train, y_train)
25
26 optimizer.apply_gradients(zip(grads, model.trainable_variables))
<ipython-input-143-58ff4de686d6> in grad(model, inputs, targets)
10 def grad(model, inputs, targets):
11 with tf.GradientTape() as tape:
---> 12 loss_value = loss(model, inputs, targets, training=False)
13 return loss_value, tape.gradient(loss_value, model.trainable_variables)
<ipython-input-143-58ff4de686d6> in loss(model, x, y, training)
6 y_ = model(x, training=training)
7
----> 8 return loss_object(y_true=y, y_pred=y_)
9
10 def grad(model, inputs, targets):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py in __call__(self, y_true, y_pred, sample_weight)
147 with K.name_scope(self._name_scope), graph_ctx:
148 ag_call = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
--> 149 losses = ag_call(y_true, y_pred)
150 return losses_utils.compute_weighted_loss(
151 losses, sample_weight, reduction=self._get_reduction())
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
253 try:
254 with conversion_ctx:
--> 255 return converted_call(f, args, kwargs, options=options)
256 except Exception as e: # pylint:disable=broad-except
257 if hasattr(e, 'ag_error_metadata'):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
455 if conversion.is_in_whitelist_cache(f, options):
456 logging.log(2, 'Whitelisted %s: from cache', f)
--> 457 return _call_unconverted(f, args, kwargs, options, False)
458
459 if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in _call_unconverted(f, args, kwargs, options, update_cache)
337
338 if kwargs is not None:
--> 339 return f(*args, **kwargs)
340 return f(*args)
341
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py in call(self, y_true, y_pred)
251 y_pred, y_true)
252 ag_fn = autograph.tf_convert(self.fn, ag_ctx.control_status_ctx())
--> 253 return ag_fn(y_true, y_pred, **self._fn_kwargs)
254
255 def get_config(self):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
199 """Call target, and fall back on dispatchers if there is a TypeError."""
200 try:
--> 201 return target(*args, **kwargs)
202 except (TypeError, ValueError):
203 # Note: convert_to_eager_tensor currently raises a ValueError, not a
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py in sparse_categorical_crossentropy(y_true, y_pred, from_logits, axis)
1562 Sparse categorical crossentropy loss value.
1563 """
-> 1564 y_pred = ops.convert_to_tensor_v2(y_pred)
1565 y_true = math_ops.cast(y_true, y_pred.dtype)
1566 return K.sparse_categorical_crossentropy(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in convert_to_tensor_v2(value, dtype, dtype_hint, name)
1380 name=name,
1381 preferred_dtype=dtype_hint,
-> 1382 as_ref=False)
1383
1384
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, dtype_hint, ctx, accepted_result_types)
1497
1498 if ret is None:
-> 1499 ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
1500
1501 if ret is NotImplemented:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py in _autopacking_conversion_function(v, dtype, name, as_ref)
1500 elif dtype != inferred_dtype:
1501 v = nest.map_structure(_cast_nested_seqs_to_dtype(dtype), v)
-> 1502 return _autopacking_helper(v, dtype, name or "packed")
1503
1504
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py in _autopacking_helper(list_or_tuple, dtype, name)
1406 # checking.
1407 if all(isinstance(elem, core.Tensor) for elem in list_or_tuple):
-> 1408 return gen_array_ops.pack(list_or_tuple, name=name)
1409 must_pack = False
1410 converted_elems = []
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_array_ops.py in pack(values, axis, name)
6457 return _result
6458 except _core._NotOkStatusException as e:
-> 6459 _ops.raise_from_not_ok_status(e, name)
6460 except _core._FallbackException:
6461 pass
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
6841 message = e.message + (" name: " + name if name is not None else "")
6842 # pylint: disable=protected-access
-> 6843 six.raise_from(core._status_to_exception(e.code, message), None)
6844 # pylint: enable=protected-access
6845
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)