I am trying to implement a perceptual loss in NiftyNet.
Unfortunately, whenever I run the following command to train the model:
net_run train -a niftynet.contrib.regression_weighted_sampler.isample_regression.ISampleRegression -c ~/niftynet/extensions/mr_ct_regression/net_isampler.ini
I get the following error output:
File "/home_local/rajeevd/miniconda3/envs/ML_research/bin/net_run", line 8, in <module>
sys.exit(main())
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/__init__.py", line 149, in main
app_driver.run(app_driver.app)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/engine/application_driver.py", line 191, in run
is_training_action=self.is_training_action)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/engine/application_driver.py", line 272, in create_graph
outputs_collector, gradients_collector)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/contrib/regression_weighted_sampler/isample_regression.py", line 89, in connect_data_and_network
self, outputs_collector, gradients_collector)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/application/regression_application.py", line 313, in connect_data_and_network
weight_map=weight_map)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/layer/base_layer.py", line 36, in __call__
return self._op(*args, **kwargs)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/template.py", line 385, in __call__
return self._call_func(args, kwargs)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/template.py", line 355, in _call_func
result = self._func(*args, **kwargs)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/layer/loss_regression.py", line 113, in layer_op
parallel_iterations=1)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 602, in new_func
return func(*args, **kwargs)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 535, in new_func
return func(*args, **kwargs)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/map_fn.py", line 651, in map_fn_v2
name=name)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 535, in new_func
return func(*args, **kwargs)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/map_fn.py", line 507, in map_fn
maximum_iterations=n)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2816, in while_loop
return_same_structure)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2298, in BuildLoop
pred, body, original_loop_vars, loop_vars, shape_invariants)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2223, in _BuildLoop
body_result = body(*packed_vars_for_body)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2768, in <lambda>
body = lambda i, lv: (i + 1, orig_body(*lv))
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/map_fn.py", line 491, in compute
result_value = autographed_fn(elems_value)
File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 695, in wrapper
raise e.ag_error_metadata.to_exception(e)
tensorflow.python.autograph.pyct.error_utils.KeyError: in user code:
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/layer/loss_regression.py:102 _batch_i_loss *
return tf.cast(self._data_loss_func(**loss_params), tf.float32)
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:889 __call__ **
result = self._call(*args, **kwds)
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:957 _call
filtered_flat_args, self._concrete_stateful_fn.captured_inputs) # pylint: disable=protected-access
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/eager/function.py:1974 _call_flat
flat_outputs = forward_function.call(ctx, args_with_tangents)
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/eager/function.py:625 call
executor_type=executor_type)
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/functional_ops.py:1221 partitioned_call
op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py:535 new_func
return func(*args, **kwargs)
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3496 create_op
attrs, op_def, compute_device)
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3565 _create_op_internal
op_def=op_def)
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:2063 __init__
self._control_flow_post_processing(input_tensors=inputs)
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:2081 _control_flow_post_processing
self._control_flow_context.AddOp(self)
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:1765 AddOp
self._AddOpInternal(op)
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:1777 _AddOpInternal
if not op.inputs:
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:2393 inputs
pywrap_tf_session.GetOperationInputs(self._c_op)))
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3946 _get_tensor_by_tf_output
op = self._get_operation_by_tf_operation(tf_output.oper)
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3910 _get_operation_by_tf_operation
return self._get_operation_by_name_unsafe(op_name)
/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3906 _get_operation_by_name_unsafe
return self._nodes_by_name[name]
KeyError: 'block1_conv1/kernel'
The only relevant change I made to the repo is that I replaced NiftyNet's RMSE loss function with the code below.
import tensorflow as tf
from tensorflow.keras import backend as K

# Frozen VGG19 feature extractor; the definition is inferred from the
# vgg.summary() output shown further down (no top layers, 320x320x3 input).
vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet',
                                  input_shape=(320, 320, 3))
vgg.trainable = False
outputs = vgg.get_layer('block4_conv2').output
model = tf.keras.Model([vgg.input], outputs)
for layer in model.layers:
    layer.trainable = False


@tf.function
def rmse_loss(prediction, ground_truth, weight_map=None):
    # Reshape the NiftyNet patches into 3-channel images for VGG19.
    prediction = tf.reshape(prediction, shape=[-1, 320, 320, 3])
    ground_truth = tf.reshape(ground_truth, shape=[-1, 320, 320, 3])
    prediction = tf.keras.applications.vgg19.preprocess_input(prediction)
    ground_truth = tf.keras.applications.vgg19.preprocess_input(ground_truth)
    # Extract block4_conv2 features for both images.
    h1_list = model(prediction)
    h2_list = model(ground_truth)
    rc_loss = 0.0
    h1 = K.batch_flatten(h1_list)
    h2 = K.batch_flatten(h2_list)
    # Sum of squared feature differences per sample (perceptual loss term).
    rc_loss = rc_loss + 1.0 * K.sum(K.square(h1 - h2), axis=-1)
    return rc_loss
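For reference, here is a minimal standalone sanity check of the loss (a sketch assuming the definitions above; the dummy tensors are hypothetical placeholders, not my real data). Called eagerly like this, outside NiftyNet's tf.map_fn graph construction, the function itself appears well-formed:
import tensorflow as tf

# Hypothetical eager sanity check: random batches whose shape matches the
# reshape inside rmse_loss.
dummy_prediction = tf.random.uniform([2, 320, 320, 3])
dummy_ground_truth = tf.random.uniform([2, 320, 320, 3])
loss_value = rmse_loss(dummy_prediction, dummy_ground_truth)
print(loss_value.shape)  # expected: (2,), one perceptual loss per sample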
Also, when I call vgg.summary(), I get:
Model: "vgg19"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 320, 320, 3)] 0
_________________________________________________________________
block1_conv1 (Conv2D) (None, 320, 320, 64) 1792
_________________________________________________________________
block1_conv2 (Conv2D) (None, 320, 320, 64) 36928
_________________________________________________________________
block1_pool (MaxPooling2D) (None, 160, 160, 64) 0
_________________________________________________________________
block2_conv1 (Conv2D) (None, 160, 160, 128) 73856
_________________________________________________________________
block2_conv2 (Conv2D) (None, 160, 160, 128) 147584
_________________________________________________________________
block2_pool (MaxPooling2D) (None, 80, 80, 128) 0
_________________________________________________________________
block3_conv1 (Conv2D) (None, 80, 80, 256) 295168
_________________________________________________________________
block3_conv2 (Conv2D) (None, 80, 80, 256) 590080
_________________________________________________________________
block3_conv3 (Conv2D) (None, 80, 80, 256) 590080
_________________________________________________________________
block3_conv4 (Conv2D) (None, 80, 80, 256) 590080
_________________________________________________________________
block3_pool (MaxPooling2D) (None, 40, 40, 256) 0
_________________________________________________________________
block4_conv1 (Conv2D) (None, 40, 40, 512) 1180160
_________________________________________________________________
block4_conv2 (Conv2D) (None, 40, 40, 512) 2359808
_________________________________________________________________
block4_conv3 (Conv2D) (None, 40, 40, 512) 2359808
_________________________________________________________________
block4_conv4 (Conv2D) (None, 40, 40, 512) 2359808
_________________________________________________________________
block4_pool (MaxPooling2D) (None, 20, 20, 512) 0
_________________________________________________________________
block5_conv1 (Conv2D) (None, 20, 20, 512) 2359808
_________________________________________________________________
block5_conv2 (Conv2D) (None, 20, 20, 512) 2359808
_________________________________________________________________
block5_conv3 (Conv2D) (None, 20, 20, 512) 2359808
_________________________________________________________________
block5_conv4 (Conv2D) (None, 20, 20, 512) 2359808
_________________________________________________________________
block5_pool (MaxPooling2D) (None, 10, 10, 512) 0
=================================================================
Total params: 20,024,384
Trainable params: 0
Non-trainable params: 20,024,384
So, at least as far as I can tell, the parameters for 'block1_conv1' are there. Does anyone know what might be causing this KeyError and how to fix it?
Thank you for your time!