
I'm trying to implement a perceptual loss in NiftyNet.

Unfortunately, whenever I train the model with the following command:

net_run train -a niftynet.contrib.regression_weighted_sampler.isample_regression.ISampleRegression -c ~/niftynet/extensions/mr_ct_regression/net_isampler.ini

I get the following error output:

  File "/home_local/rajeevd/miniconda3/envs/ML_research/bin/net_run", line 8, in <module>
    sys.exit(main())
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/__init__.py", line 149, in main
    app_driver.run(app_driver.app)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/engine/application_driver.py", line 191, in run
    is_training_action=self.is_training_action)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/engine/application_driver.py", line 272, in create_graph
    outputs_collector, gradients_collector)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/contrib/regression_weighted_sampler/isample_regression.py", line 89, in connect_data_and_network
    self, outputs_collector, gradients_collector)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/application/regression_application.py", line 313, in connect_data_and_network
    weight_map=weight_map)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/layer/base_layer.py", line 36, in __call__
    return self._op(*args, **kwargs)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/template.py", line 385, in __call__
    return self._call_func(args, kwargs)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/template.py", line 355, in _call_func
    result = self._func(*args, **kwargs)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/layer/loss_regression.py", line 113, in layer_op
    parallel_iterations=1)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 602, in new_func
    return func(*args, **kwargs)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 535, in new_func
    return func(*args, **kwargs)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/map_fn.py", line 651, in map_fn_v2
    name=name)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 535, in new_func
    return func(*args, **kwargs)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/map_fn.py", line 507, in map_fn
    maximum_iterations=n)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2816, in while_loop
    return_same_structure)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2298, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2223, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2768, in <lambda>
    body = lambda i, lv: (i + 1, orig_body(*lv))
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/map_fn.py", line 491, in compute
    result_value = autographed_fn(elems_value)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 695, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.autograph.pyct.error_utils.KeyError: in user code:

    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/layer/loss_regression.py:102 _batch_i_loss  *
        return tf.cast(self._data_loss_func(**loss_params), tf.float32)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:889 __call__  **
        result = self._call(*args, **kwds)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:957 _call
        filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/eager/function.py:1974 _call_flat
        flat_outputs = forward_function.call(ctx, args_with_tangents)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/eager/function.py:625 call
        executor_type=executor_type)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/functional_ops.py:1221 partitioned_call
        op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py:535 new_func
        return func(*args, **kwargs)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3496 create_op
        attrs, op_def, compute_device)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3565 _create_op_internal
        op_def=op_def)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:2063 __init__
        self._control_flow_post_processing(input_tensors=inputs)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:2081 _control_flow_post_processing
        self._control_flow_context.AddOp(self)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:1765 AddOp
        self._AddOpInternal(op)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:1777 _AddOpInternal
        if not op.inputs:
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:2393 inputs
        pywrap_tf_session.GetOperationInputs(self._c_op)))
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3946 _get_tensor_by_tf_output
        op = self._get_operation_by_tf_operation(tf_output.oper)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3910 _get_operation_by_tf_operation
        return self._get_operation_by_name_unsafe(op_name)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3906 _get_operation_by_name_unsafe
        return self._nodes_by_name[name]

    KeyError: 'block1_conv1/kernel'

The only relevant change I made to the repo was replacing NiftyNet's RMSE loss function with the code below:

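import tensorflow as tf
from tensorflow.keras import backend as K

# Assumed construction (not shown in the original post), inferred from the
# vgg.summary() output below: VGG19 without the classifier head, 320x320x3 input.
# weights="imagenet" is also an assumption.
vgg = tf.keras.applications.VGG19(include_top=False, weights="imagenet",
                                  input_shape=(320, 320, 3))
vgg.trainable = False
# Feature extractor that returns the block4_conv2 activations.
outputs = vgg.get_layer('block4_conv2').output
model = tf.keras.Model([vgg.input], outputs)
for layer in model.layers:
    layer.trainable = False

@tf.function
def rmse_loss(prediction, ground_truth, weight_map=None):
    # Reshape the incoming tensors into 320x320 RGB images for VGG19.
    prediction = tf.reshape(prediction, shape=[-1, 320, 320, 3])
    ground_truth = tf.reshape(ground_truth, shape=[-1, 320, 320, 3])
    prediction = tf.keras.applications.vgg19.preprocess_input(prediction)
    ground_truth = tf.keras.applications.vgg19.preprocess_input(ground_truth)
    # Compare the block4_conv2 feature maps of prediction and ground truth.
    h1_list = model(prediction)
    h2_list = model(ground_truth)
    rc_loss = 0.0
    h1 = K.batch_flatten(h1_list)
    h2 = K.batch_flatten(h2_list)
    # Per-sample sum of squared feature differences (weight_map is ignored).
    rc_loss = rc_loss + 1.0 * K.sum(K.square(h1 - h2), axis=-1)
    return rc_loss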
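For reference, the last statements of rmse_loss compute a per-sample sum of squared feature differences (one value per sample, not a root-mean-square despite the function name), and vgg19.preprocess_input applies Keras' "caffe"-mode preprocessing: it flips RGB to BGR and subtracts the ImageNet channel means without rescaling, so it expects inputs in the 0 to 255 range. The NumPy sketch below mimics both computations on small stand-in arrays; the function names are hypothetical and the arrays are not real VGG19 features.

```python
import numpy as np

# ImageNet channel means in BGR order, as subtracted by Keras' "caffe"-mode
# preprocessing (the mode vgg19.preprocess_input uses).
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def caffe_preprocess(x):
    """Mimic vgg19.preprocess_input for channels-last RGB input in [0, 255]:
    flip RGB -> BGR, then subtract the per-channel ImageNet mean (no scaling)."""
    x = np.asarray(x, dtype=np.float32)
    return x[..., ::-1] - IMAGENET_BGR_MEAN


def per_sample_feature_distance(h1, h2):
    """NumPy analogue of K.sum(K.square(K.batch_flatten(h1) - K.batch_flatten(h2)), axis=-1):
    flatten all axes except the batch axis, then sum squared differences per sample."""
    h1 = np.asarray(h1, dtype=np.float32)
    h2 = np.asarray(h2, dtype=np.float32)
    h1 = h1.reshape(h1.shape[0], -1)
    h2 = h2.reshape(h2.shape[0], -1)
    return np.sum(np.square(h1 - h2), axis=-1)


# Stand-in "feature maps" for a batch of 2 samples (not real VGG19 activations).
h1 = np.ones((2, 4, 4, 8))
h2 = np.zeros((2, 4, 4, 8))
print(per_sample_feature_distance(h1, h2))  # [128. 128.]
print(caffe_preprocess([[[255.0, 0.0, 0.0]]]))  # one pure-red RGB pixel
```

The distance has shape [batch] rather than being a scalar, so any averaging over samples has to happen in the code that calls the loss.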

Also, when I call vgg.summary(), I get:

Model: "vgg19"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 320, 320, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 320, 320, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 320, 320, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 160, 160, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 160, 160, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 160, 160, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 80, 80, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 80, 80, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 80, 80, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 80, 80, 256)       590080    
_________________________________________________________________
block3_conv4 (Conv2D)        (None, 80, 80, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 40, 40, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 40, 40, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 40, 40, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 40, 40, 512)       2359808   
_________________________________________________________________
block4_conv4 (Conv2D)        (None, 40, 40, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 20, 20, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 20, 20, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 20, 20, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 20, 20, 512)       2359808   
_________________________________________________________________
block5_conv4 (Conv2D)        (None, 20, 20, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 10, 10, 512)       0         
=================================================================
Total params: 20,024,384
Trainable params: 0
Non-trainable params: 20,024,384
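The parameter counts in this summary can be reproduced by hand. Each Conv2D layer holds a (3, 3, c_in, c_out) kernel plus c_out biases, so the 1,792 parameters of block1_conv1 are exactly a 3x3x3x64 kernel and 64 biases. In other words, the variable named in the KeyError does exist in the Keras model object. Because the traceback ends in a lookup in a graph's _nodes_by_name table, this suggests (but does not prove) that the problem is which tf.Graph TensorFlow is searching rather than missing weights. The layer list below is transcribed from the summary; the 3x3 kernel size is the standard VGG19 configuration.

```python
# VGG19 convolutional layers as (name, in_channels, out_channels), transcribed
# from the summary; every conv layer in VGG19 uses a 3x3 kernel.
VGG19_CONVS = [
    ("block1_conv1", 3, 64), ("block1_conv2", 64, 64),
    ("block2_conv1", 64, 128), ("block2_conv2", 128, 128),
    ("block3_conv1", 128, 256), *[(f"block3_conv{i}", 256, 256) for i in (2, 3, 4)],
    ("block4_conv1", 256, 512), *[(f"block4_conv{i}", 512, 512) for i in (2, 3, 4)],
    *[(f"block5_conv{i}", 512, 512) for i in (1, 2, 3, 4)],
]


def conv2d_params(c_in, c_out, k=3):
    """Parameters of a Conv2D layer: a (k, k, c_in, c_out) kernel plus c_out biases."""
    return k * k * c_in * c_out + c_out


# Pooling layers contribute no parameters, so the total is the sum over conv layers.
param_counts = {name: conv2d_params(c_in, c_out) for name, c_in, c_out in VGG19_CONVS}
total = sum(param_counts.values())
print(param_counts["block1_conv1"])  # 1792
print(total)  # 20024384
```

The computed total matches the "Total params: 20,024,384" line above.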

So, at least as far as I can tell, the parameters of 'block1_conv1' are there. Does anyone know what could cause this KeyError and how to fix it?

Thanks for your time!
