我正在尝试用感知损失替换 Niftynet 的标准 RMSE 损失函数。不幸的是,每次我尝试跑步net_run train -a niftynet.contrib.regression_weighted_sampler.isample_regression.ISampleRegression -c ~/niftynet/extensions/mr_ct_regression/net_isampler.ini --starting_iter 0 --max_iter 1000
时,
我收到以下错误:
Traceback (most recent call last):
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/bin/net_run", line 8, in <module>
sys.exit(main())
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/__init__.py", line 147, in main
app_driver.run(app_driver.app)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/engine/application_driver.py", line 189, in run
is_training_action=self.is_training_action)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/engine/application_driver.py", line 270, in create_graph
outputs_collector, gradients_collector)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/contrib/regression_weighted_sampler/isample_regression.py", line 89, in connect_data_and_network
self, outputs_collector, gradients_collector)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/application/regression_application.py", line 311, in connect_data_and_network
weight_map=weight_map)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/layer/base_layer.py", line 34, in __call__
return self._op(*args, **kwargs)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/template.py", line 385, in __call__
return self._call_func(args, kwargs)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/template.py", line 355, in _call_func
result = self._func(*args, **kwargs)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/layer/loss_regression.py", line 102, in layer_op
parallel_iterations=1)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/map_fn.py", line 268, in map_fn
maximum_iterations=n)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/control_flow_ops.py", line 2753, in while_loop
return_same_structure)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/control_flow_ops.py", line 2245, in BuildLoop
pred, body, original_loop_vars, loop_vars, shape_invariants)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/control_flow_ops.py", line 2170, in _BuildLoop
body_result = body(*packed_vars_for_body)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/control_flow_ops.py", line 2705, in <lambda>
body = lambda i, lv: (i + 1, orig_body(*lv))
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/map_fn.py", line 257, in compute
packed_fn_values = fn(packed_values)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/layer/loss_regression.py", line 91, in _batch_i_loss
return tf.to_float(self._data_loss_func(**loss_params))
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/layer/loss_regression.py", line 180, in rmse_loss
pooling = None
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/applications/__init__.py", line 49, in wrapper
return base_fun(*args, **kwargs)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/applications/vgg16.py", line 32, in VGG16
return vgg16.VGG16(*args, **kwargs)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/keras_applications/vgg16.py", line 210, in VGG16
model.load_weights(weights_path)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 182, in load_weights
return super(Model, self).load_weights(filepath, by_name)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py", line 1373, in load_weights
saving.load_weights_from_hdf5_group(f, self.layers)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py", line 693, in load_weights_from_hdf5_group
K.batch_set_value(weight_value_tuples)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py", line 3259, in batch_set_value
get_session().run(assign_ops, feed_dict=feed_dict)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py", line 486, in get_session
_initialize_variables(session)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py", line 903, in _initialize_variables
[variables_module.is_variable_initialized(v) for v in candidate_vars])
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 956, in run
run_metadata_ptr)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1165, in _run
self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 488, in __init__
self._assert_fetchable(graph, fetch.op)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 505, in _assert_fetchable
% op.name)
tensorflow.python.framework.errors_impl.InaccessibleTensorError: Operation 'worker_0/loss_function/map/while/VarIsInitializedOp' has been marked as not fetchable. Typically this happens when it is defined in another function or code block. Use return values,explicit Python locals or TensorFlow collections to access it.
originally defined at:
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/application/regression_application.py", line 296, in connect_data_and_network
loss_func = LossFunction(loss_type=self.action_param.loss_type)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/layer/loss_regression.py", line 20, in __init__
super(LossFunction, self).__init__(name=name)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/layer/base_layer.py", line 26, in __init__
self._op = tf.compat.v1.make_template(name, self.layer_op, create_scope_now_=True)
File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/template.py", line 161, in make_template
**kwargs)
我对 Niftynet 存储库的唯一重大更改是我用下面的代码替换了 loss_regression.py 文件中的“rmse_loss”函数。
def rmse_loss(prediction, ground_truth, weight_map=None):
"""
:param prediction: the current prediction of the ground truth.
:param ground_truth: the measurement you are approximating with regression.
:param weight_map: a weight map for the cost function. .
:return: sqrt(mean(differences squared))
"""
"""if weight_map is not None:
residuals = tf.subtract(prediction, ground_truth)
residuals = tf.multiply(residuals, residuals)
residuals = tf.multiply(residuals, weight_map)
return tf.sqrt(tf.reduce_mean(residuals) / tf.reduce_mean(weight_map))
else:
return tf.sqrt(tf.losses.mean_squared_error(prediction, ground_truth))"""
prediction = tf.reshape(prediction, shape=[-1, 288])
ground_truth = tf.reshape(ground_truth, shape=[-1, 288])
prediction = tf.stack([prediction, prediction, prediction], axis=-1)
ground_truth = tf.stack([ground_truth, ground_truth, ground_truth], axis=-1)
prediction = tf.expand_dims(prediction, axis=0)
ground_truth = tf.expand_dims(ground_truth, axis=0)
prediction = tf.keras.applications.vgg16.preprocess_input(prediction)
ground_truth = tf.keras.applications.vgg16.preprocess_input(ground_truth)
pretrained_model = tf.keras.applications.vgg16.VGG16(include_top = False,
weights = 'imagenet',
pooling = None
)
pretrained_model.trainable = False
prediction_features = pretrained_model(prediction)
ground_truth_features = pretrained_model(ground_truth)
return tf.sqrt(tf.losses.mean_squared_error(prediction_features, ground_truth_features))
函数开头的注释掉的代码是函数的原始代码。另外,据我了解,错误来自我创建“pretrained_model”(第 28 行)时的代码行。有谁知道如何防止这个错误?
感谢您的时间!