我正在尝试使用 tensorflow 和tf.data
. 我想复制 的功能PIL.Image.Image.crop
,其中可以传递负边界框值,以便用零扩展裁剪。
例如,如果我调用PIL.Image.Image.crop([-10, 0, img_height, img_width])
,则图像的开头有 10 行额外的行,用零填充。
据我了解,在 tf.data 输入管道中使用 python 函数会显着减慢代码速度,因此我尝试使用 tensorflow 函数编写所有内容。我还想使用 API 已经提供的预取、批处理、改组等。
我使用 tensorflow 函数实现 PIL 裁剪的计划是预先分配一个(动态成形的)零张量并使用切片分配裁剪值。
这是我遇到的错误:NotImplementedError: Cannot convert a symbolic Tensor (args_2:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported
这是复制问题的最少代码:
import tensorflow as tf
ds2 = tf.data.Dataset.from_tensor_slices(np.arange(10))
def preproc(number):
o = tf.zeros((number,number))
# ...
ds2.map(preproc)
问题1:我该如何解决这个问题?
问题 2:我来自 PyTorch 背景,我对整个 tf.data 管道的复杂性感到困惑。为什么我要使用 tensorflow 函数才能使用所有不错的功能?
作为参考,这是我到目前为止的完整代码。
class MyPreprocessing:
def __call__(self, *data):
# Complete code omitted for simplicity.
# This is called somewhere:
self._crop(...)
def _crop(self, img, center, body_size, res):
tl = tf.cast(center - body_size/2, dtype=tf.int32)
br = tf.cast(center + body_size/2, dtype=tf.int32)
height, width = tf.shape(img)[0], tf.shape(img)[1]
crop_tl = tf.stack([
tf.cond(tl[0] < 0, lambda: tf.constant(0, dtype=tf.int32), lambda: tl[0]),
tf.cond(tl[1] < 0, lambda: tf.constant(0, dtype=tf.int32), lambda: tl[1])])
crop_br = tf.stack([
tf.cond(br[0] > height, lambda: height, lambda: br[0]),
tf.cond(br[1] > width, lambda: width, lambda: br[1])])
crop = tf.image.crop_to_bounding_box(
img,
crop_tl[0],
crop_tl[1],
crop_br[0] - crop_tl[1],
crop_br[1] - crop_tl[1])
new_tl = crop_tl - tl
new_br = crop_br - tl
# Error:
new_img = tf.zeros((body_size, body_size, tf.constant(3)), dtype=tf.float32)
new_img[new_tl[0]:new_br[0], new_tl[1]:new_br[1]].assign(crop)
return tf.image.resize(new_img, (res, res))
编辑:完整的堆栈跟踪
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
<ipython-input-2-84d20c31b9c9> in <module>
8 # ...
9
---> 10 ds2.map(preproc)
~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py in map(self, map_func, num_parallel_calls, deterministic)
1923 warnings.warn("The `deterministic` argument has no effect unless the "
1924 "`num_parallel_calls` argument is specified.")
-> 1925 return MapDataset(self, map_func, preserve_cardinality=True)
1926 else:
1927 return ParallelMapDataset(
~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, input_dataset, map_func, use_inter_op_parallelism, preserve_cardinality, use_legacy_function)
4481 self._use_inter_op_parallelism = use_inter_op_parallelism
4482 self._preserve_cardinality = preserve_cardinality
-> 4483 self._map_func = StructuredFunctionWrapper(
4484 map_func,
4485 self._transformation_name(),
~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
3710 resource_tracker = tracking.ResourceTracker()
3711 with tracking.resource_tracker_scope(resource_tracker):
-> 3712 self._function = fn_factory()
3713 # There is no graph to add in eager mode.
3714 add_to_graph &= not context.executing_eagerly()
~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py in get_concrete_function(self, *args, **kwargs)
3132 or `tf.Tensor` or `tf.TensorSpec`.
3133 """
-> 3134 graph_function = self._get_concrete_function_garbage_collected(
3135 *args, **kwargs)
3136 graph_function._garbage_collector.release() # pylint: disable=protected-access
~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
3098 args, kwargs = None, None
3099 with self._lock:
-> 3100 graph_function, _ = self._maybe_define_function(args, kwargs)
3101 seen_names = set()
3102 captured = object_identity.ObjectIdentitySet(
~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3442
3443 self._function_cache.missed.add(call_context_key)
-> 3444 graph_function = self._create_graph_function(args, kwargs)
3445 self._function_cache.primary[cache_key] = graph_function
3446
~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3277 arg_names = base_arg_names + missing_arg_names
3278 graph_function = ConcreteFunction(
-> 3279 func_graph_module.func_graph_from_py_func(
3280 self._name,
3281 self._python_function,
~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
997 _, original_func = tf_decorator.unwrap(python_func)
998
--> 999 func_outputs = python_func(*func_args, **func_kwargs)
1000
1001 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py in wrapped_fn(*args)
3685 attributes=defun_kwargs)
3686 def wrapped_fn(*args): # pylint: disable=missing-docstring
-> 3687 ret = wrapper_helper(*args)
3688 ret = structure.to_tensor_list(self._output_structure, ret)
3689 return [ops.convert_to_tensor(t) for t in ret]
~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py in wrapper_helper(*args)
3615 if not _should_unpack(nested_args):
3616 nested_args = (nested_args,)
-> 3617 ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
3618 if _should_pack(ret):
3619 ret = tuple(ret)
~/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
693 except Exception as e: # pylint:disable=broad-except
694 if hasattr(e, 'ag_error_metadata'):
--> 695 raise e.ag_error_metadata.to_exception(e)
696 else:
697 raise
NotImplementedError: in user code:
<ipython-input-2-84d20c31b9c9>:7 preproc *
o = tf.zeros((number,number))
/home/benjs/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:206 wrapper **
return target(*args, **kwargs)
/home/benjs/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py:2911 wrapped
tensor = fun(*args, **kwargs)
/home/benjs/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py:2960 zeros
output = _constant_if_small(zero, shape, dtype, name)
/home/benjs/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py:2896 _constant_if_small
if np.prod(shape) < 1000:
<__array_function__ internals>:5 prod
/home/benjs/Documents/projects/hpe/venv/lib/python3.8/site-packages/numpy/core/fromnumeric.py:3030 prod
return _wrapreduction(a, np.multiply, 'prod', axis, dtype, out,
/home/benjs/Documents/projects/hpe/venv/lib/python3.8/site-packages/numpy/core/fromnumeric.py:87 _wrapreduction
return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
/home/benjs/Documents/projects/hpe/venv/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:867 __array__
raise NotImplementedError(
NotImplementedError: Cannot convert a symbolic Tensor (args_0:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported