
I am trying to replace NiftyNet's standard RMSE loss function with a perceptual loss. Unfortunately, every time I try to run

net_run train -a niftynet.contrib.regression_weighted_sampler.isample_regression.ISampleRegression -c ~/niftynet/extensions/mr_ct_regression/net_isampler.ini --starting_iter 0 --max_iter 1000

I get the following error:

Traceback (most recent call last):
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/bin/net_run", line 8, in <module>
    sys.exit(main())
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/__init__.py", line 147, in main
    app_driver.run(app_driver.app)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/engine/application_driver.py", line 189, in run
    is_training_action=self.is_training_action)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/engine/application_driver.py", line 270, in create_graph
    outputs_collector, gradients_collector)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/contrib/regression_weighted_sampler/isample_regression.py", line 89, in connect_data_and_network
    self, outputs_collector, gradients_collector)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/application/regression_application.py", line 311, in connect_data_and_network
    weight_map=weight_map)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/layer/base_layer.py", line 34, in __call__
    return self._op(*args, **kwargs)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/template.py", line 385, in __call__
    return self._call_func(args, kwargs)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/template.py", line 355, in _call_func
    result = self._func(*args, **kwargs)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/layer/loss_regression.py", line 102, in layer_op
    parallel_iterations=1)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/map_fn.py", line 268, in map_fn
    maximum_iterations=n)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/control_flow_ops.py", line 2753, in while_loop
    return_same_structure)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/control_flow_ops.py", line 2245, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/control_flow_ops.py", line 2170, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/control_flow_ops.py", line 2705, in <lambda>
    body = lambda i, lv: (i + 1, orig_body(*lv))
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/map_fn.py", line 257, in compute
    packed_fn_values = fn(packed_values)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/layer/loss_regression.py", line 91, in _batch_i_loss
    return tf.to_float(self._data_loss_func(**loss_params))
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/layer/loss_regression.py", line 180, in rmse_loss
    pooling = None
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/applications/__init__.py", line 49, in wrapper
    return base_fun(*args, **kwargs)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/applications/vgg16.py", line 32, in VGG16
    return vgg16.VGG16(*args, **kwargs)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/keras_applications/vgg16.py", line 210, in VGG16
    model.load_weights(weights_path)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 182, in load_weights
    return super(Model, self).load_weights(filepath, by_name)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py", line 1373, in load_weights
    saving.load_weights_from_hdf5_group(f, self.layers)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py", line 693, in load_weights_from_hdf5_group
    K.batch_set_value(weight_value_tuples)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py", line 3259, in batch_set_value
    get_session().run(assign_ops, feed_dict=feed_dict)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py", line 486, in get_session
    _initialize_variables(session)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py", line 903, in _initialize_variables
    [variables_module.is_variable_initialized(v) for v in candidate_vars])
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 956, in run
    run_metadata_ptr)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1165, in _run
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 488, in __init__
    self._assert_fetchable(graph, fetch.op)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 505, in _assert_fetchable
    % op.name)
tensorflow.python.framework.errors_impl.InaccessibleTensorError: Operation 'worker_0/loss_function/map/while/VarIsInitializedOp' has been marked as not fetchable. Typically this happens when it is defined in another function or code block. Use return values,explicit Python locals or TensorFlow collections to access it.

originally defined at:
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/application/regression_application.py", line 296, in connect_data_and_network
    loss_func = LossFunction(loss_type=self.action_param.loss_type)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/layer/loss_regression.py", line 20, in __init__
    super(LossFunction, self).__init__(name=name)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/niftynet/layer/base_layer.py", line 26, in __init__
    self._op = tf.compat.v1.make_template(name, self.layer_op, create_scope_now_=True)
  File "/home_local/rajeevd/miniconda3/envs/my_virt_env/lib/python3.7/site-packages/tensorflow_core/python/ops/template.py", line 161, in make_template
    **kwargs)

The only significant change I made to the NiftyNet repository is that I replaced the "rmse_loss" function in loss_regression.py with the code below.

def rmse_loss(prediction, ground_truth, weight_map=None):
    """
    :param prediction: the current prediction of the ground truth.
    :param ground_truth: the measurement you are approximating with regression.
    :param weight_map: a weight map for the cost function.
    :return: sqrt(mean(differences squared))
    """
    """if weight_map is not None:
        residuals = tf.subtract(prediction, ground_truth)
        residuals = tf.multiply(residuals, residuals)
        residuals = tf.multiply(residuals, weight_map)
        return tf.sqrt(tf.reduce_mean(residuals) / tf.reduce_mean(weight_map))
    else:
        return tf.sqrt(tf.losses.mean_squared_error(prediction, ground_truth))"""

    prediction = tf.reshape(prediction, shape=[-1, 288])
    ground_truth = tf.reshape(ground_truth, shape=[-1, 288])

    prediction = tf.stack([prediction, prediction, prediction], axis=-1)
    ground_truth = tf.stack([ground_truth, ground_truth, ground_truth], axis=-1)

    prediction = tf.expand_dims(prediction, axis=0)
    ground_truth = tf.expand_dims(ground_truth, axis=0)

    prediction = tf.keras.applications.vgg16.preprocess_input(prediction)
    ground_truth = tf.keras.applications.vgg16.preprocess_input(ground_truth)

    pretrained_model = tf.keras.applications.vgg16.VGG16(include_top = False,
        weights = 'imagenet',
        pooling = None
        )

    pretrained_model.trainable = False

    prediction_features = pretrained_model(prediction)
    ground_truth_features = pretrained_model(ground_truth)

    return tf.sqrt(tf.losses.mean_squared_error(prediction_features, ground_truth_features))
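For reference, the original rmse_loss (the commented-out block at the top of the function above) computes a pixel-wise RMSE, optionally weighted as sqrt(mean(w·(p − g)²) / mean(w)), while the replacement computes the RMSE between the outputs of a frozen VGG16, i.e. a perceptual (feature-reconstruction) loss. The plain-Python sketch below writes both out on flat lists. `rmse`, `perceptual_rmse` and the caller-supplied `extract_features` callable (standing in for VGG16) are hypothetical names, not NiftyNet or Keras APIs.

```python
import math
from typing import Callable, Optional, Sequence


def rmse(pred: Sequence[float], truth: Sequence[float],
         weights: Optional[Sequence[float]] = None) -> float:
    """Pixel-wise RMSE, mirroring the original (commented-out) rmse_loss.

    weights=None: sqrt(mean((p - g)^2))
    otherwise:    sqrt(mean(w * (p - g)^2) / mean(w))
    """
    if len(pred) != len(truth):
        raise ValueError("pred and truth must have the same length")
    squared = [(p - g) ** 2 for p, g in zip(pred, truth)]
    if weights is None:
        return math.sqrt(sum(squared) / len(squared))
    if len(weights) != len(squared):
        raise ValueError("weights must match pred/truth in length")
    weighted = sum(w * s for w, s in zip(weights, squared)) / len(squared)
    return math.sqrt(weighted / (sum(weights) / len(weights)))


def perceptual_rmse(pred: Sequence[float], truth: Sequence[float],
                    extract_features: Callable[[Sequence[float]], Sequence[float]]) -> float:
    """RMSE in the feature space of a fixed (non-trainable) extractor.

    extract_features stands in for the frozen VGG16: the same extractor is
    applied to prediction and ground truth, and the loss compares features.
    """
    return rmse(extract_features(pred), extract_features(truth))
```

With a toy extractor that takes neighbouring differences, a prediction that is off by a constant 1.0 everywhere has a pixel RMSE of 1.0 but a feature-space RMSE of 0.0: the perceptual loss only penalizes differences that the extractor responds to.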

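The commented-out code at the start of the function is the function's original code. Also, as far as I can tell, the error comes from the line where I create "pretrained_model" (line 28 of the function). Does anyone know how to prevent this error?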
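The NiftyNet frames of the traceback are consistent with that line being the trigger: LossFunction.layer_op (loss_regression.py, line 102) wraps the data loss in tf.map_fn(..., parallel_iterations=1), so rmse_loss is traced while TensorFlow builds a while_loop body (_batch_i_loss). Constructing VGG16 there runs load_weights → batch_set_value → get_session, which tries to run VarIsInitializedOp ops created inside that loop (worker_0/loss_function/map/while/...), and the error reports those as not fetchable. One direction that might avoid this, untested here, is to construct the model once before map_fn runs (hypothetically, when the loss layer is created) and let the per-sample loss only apply it. The stdlib-only sketch below shows that build-once, apply-per-sample structure; `build_extractor` and `make_perceptual_loss` are hypothetical stand-ins, and the fixed difference kernel only plays the role of VGG16's frozen weights.

```python
from typing import Callable, List, Sequence

Vector = Sequence[float]

build_calls = 0  # how many times the (expensive) extractor was constructed


def build_extractor() -> Callable[[Vector], List[float]]:
    """Hypothetical stand-in for building VGG16 and loading its weights."""
    global build_calls
    build_calls += 1
    kernel = (1.0, -1.0)  # fixed, never-trained "weights"

    def extract(x: Vector) -> List[float]:
        # Slide the fixed kernel over x: output[i] = x[i] - x[i + 1].
        return [sum(k * x[i + j] for j, k in enumerate(kernel))
                for i in range(len(x) - len(kernel) + 1)]

    return extract


def make_perceptual_loss(
        extract: Callable[[Vector], List[float]]) -> Callable[[Vector, Vector], float]:
    """Return a loss that only applies an extractor built beforehand."""
    def loss(pred: Vector, truth: Vector) -> float:
        fp, ft = extract(pred), extract(truth)
        return (sum((a - b) ** 2 for a, b in zip(fp, ft)) / len(fp)) ** 0.5
    return loss


# Build the extractor once, outside the per-sample computation ...
loss_fn = make_perceptual_loss(build_extractor())

# ... then apply it to each sample, analogous to map_fn applying the
# loss to each batch element.
batch = [([1.0, 2.0, 4.0], [1.0, 2.0, 3.0]),
         ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0])]
losses = [loss_fn(pred, truth) for pred, truth in batch]
```

The design choice is only about where construction happens: the extractor (and, in the real case, the weight loading) is created once when loss_fn is built, and evaluating the loss never creates new model state.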

Thanks for your time!
