0

我正在尝试使用 tensorflow 和tf.data. 我想复制 的功能PIL.Image.Image.crop,其中可以传递负边界框值,以便用零扩展裁剪。

例如,如果我调用PIL.Image.Image.crop([-10, 0, img_height, img_width]),则图像的开头有 10 行额外的行,用零填充。

据我了解,在 tf.data 输入管道中使用 python 函数会显着减慢代码速度,因此我尝试使用 tensorflow 函数编写所有内容。我还想使用 API 已经提供的预取、批处理、改组等。

我使用 tensorflow 函数实现 PIL 裁剪的计划是预先分配一个(动态成形的)零张量并使用切片分配裁剪值。

这是我遇到的错误:NotImplementedError: Cannot convert a symbolic Tensor (args_2:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported

这是复制问题的最少代码:

import tensorflow as tf
ds2 = tf.data.Dataset.from_tensor_slices(np.arange(10))

def preproc(number):
    o = tf.zeros((number,number))
    # ...

ds2.map(preproc)

问题1:我该如何解决这个问题?

问题 2:我来自 PyTorch 背景,我对整个 tf.data 管道的复杂性感到困惑。为什么我要使用 tensorflow 函数才能使用所有不错的功能?


作为参考,这是我到目前为止的完整代码。

class MyPreprocessing:
    def __call__(self, *data):
        # Complete code omitted for simplicity. 
        # This is called somewhere:
        self._crop(...)

    def _crop(self, img, center, body_size, res):
        tl = tf.cast(center - body_size/2, dtype=tf.int32)
        br = tf.cast(center + body_size/2, dtype=tf.int32)

        height, width = tf.shape(img)[0], tf.shape(img)[1]

        crop_tl = tf.stack([
            tf.cond(tl[0] < 0, lambda: tf.constant(0, dtype=tf.int32), lambda: tl[0]),
            tf.cond(tl[1] < 0, lambda: tf.constant(0, dtype=tf.int32), lambda: tl[1])])
        crop_br = tf.stack([
            tf.cond(br[0] > height, lambda: height, lambda: br[0]),
            tf.cond(br[1] > width, lambda: width, lambda: br[1])])

        crop = tf.image.crop_to_bounding_box(
            img,
            crop_tl[0],
            crop_tl[1],
            crop_br[0] - crop_tl[1],
            crop_br[1] - crop_tl[1])

        new_tl = crop_tl - tl
        new_br = crop_br - tl

        # Error:
        new_img = tf.zeros((body_size, body_size, tf.constant(3)), dtype=tf.float32)
        new_img[new_tl[0]:new_br[0], new_tl[1]:new_br[1]].assign(crop)

        return tf.image.resize(new_img, (res, res))

编辑:完整的堆栈跟踪

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-2-84d20c31b9c9> in <module>
      8     # ...
      9 
---> 10 ds2.map(preproc)

~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py in map(self, map_func, num_parallel_calls, deterministic)
   1923         warnings.warn("The `deterministic` argument has no effect unless the "
   1924                       "`num_parallel_calls` argument is specified.")
-> 1925       return MapDataset(self, map_func, preserve_cardinality=True)
   1926     else:
   1927       return ParallelMapDataset(

~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, input_dataset, map_func, use_inter_op_parallelism, preserve_cardinality, use_legacy_function)
   4481     self._use_inter_op_parallelism = use_inter_op_parallelism
   4482     self._preserve_cardinality = preserve_cardinality
-> 4483     self._map_func = StructuredFunctionWrapper(
   4484         map_func,
   4485         self._transformation_name(),

~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
   3710     resource_tracker = tracking.ResourceTracker()
   3711     with tracking.resource_tracker_scope(resource_tracker):
-> 3712       self._function = fn_factory()
   3713       # There is no graph to add in eager mode.
   3714       add_to_graph &= not context.executing_eagerly()

~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py in get_concrete_function(self, *args, **kwargs)
   3132          or `tf.Tensor` or `tf.TensorSpec`.
   3133     """
-> 3134     graph_function = self._get_concrete_function_garbage_collected(
   3135         *args, **kwargs)
   3136     graph_function._garbage_collector.release()  # pylint: disable=protected-access

~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   3098       args, kwargs = None, None
   3099     with self._lock:
-> 3100       graph_function, _ = self._maybe_define_function(args, kwargs)
   3101       seen_names = set()
   3102       captured = object_identity.ObjectIdentitySet(

~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3442 
   3443           self._function_cache.missed.add(call_context_key)
-> 3444           graph_function = self._create_graph_function(args, kwargs)
   3445           self._function_cache.primary[cache_key] = graph_function
   3446 

~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3277     arg_names = base_arg_names + missing_arg_names
   3278     graph_function = ConcreteFunction(
-> 3279         func_graph_module.func_graph_from_py_func(
   3280             self._name,
   3281             self._python_function,

~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    997         _, original_func = tf_decorator.unwrap(python_func)
    998 
--> 999       func_outputs = python_func(*func_args, **func_kwargs)
   1000 
   1001       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py in wrapped_fn(*args)
   3685           attributes=defun_kwargs)
   3686       def wrapped_fn(*args):  # pylint: disable=missing-docstring
-> 3687         ret = wrapper_helper(*args)
   3688         ret = structure.to_tensor_list(self._output_structure, ret)
   3689         return [ops.convert_to_tensor(t) for t in ret]

~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py in wrapper_helper(*args)
   3615       if not _should_unpack(nested_args):
   3616         nested_args = (nested_args,)
-> 3617       ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
   3618       if _should_pack(ret):
   3619         ret = tuple(ret)

~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    693       except Exception as e:  # pylint:disable=broad-except
    694         if hasattr(e, 'ag_error_metadata'):
--> 695           raise e.ag_error_metadata.to_exception(e)
    696         else:
    697           raise

NotImplementedError: in user code:

    <ipython-input-2-84d20c31b9c9>:7 preproc  *
        o = tf.zeros((number,number))
    /home/benjs/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper  **
        return target(*args, **kwargs)
    /home/benjs/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py:2911 wrapped
        tensor = fun(*args, **kwargs)
    /home/benjs/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py:2960 zeros
        output = _constant_if_small(zero, shape, dtype, name)
    /home/benjs/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py:2896 _constant_if_small
        if np.prod(shape) < 1000:
    <__array_function__ internals>:5 prod
        
    /home/benjs/Documents/projects/hpe/venv/lib/python3.8/site-packages/numpy/core/fromnumeric.py:3030 prod
        return _wrapreduction(a, np.multiply, 'prod', axis, dtype, out,
    /home/benjs/Documents/projects/hpe/venv/lib/python3.8/site-packages/numpy/core/fromnumeric.py:87 _wrapreduction
        return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
    /home/benjs/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:867 __array__
        raise NotImplementedError(

    NotImplementedError: Cannot convert a symbolic Tensor (args_0:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported
4

0 回答 0