40

我有一个非常重要的输入管道,from_generator非常适合......

dataset = tf.data.Dataset.from_generator(complex_img_label_generator,
                                        (tf.int32, tf.string))
dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()

wherecomplex_img_label_generator动态生成图像并返回一个表示(H, W, 3)图像和简单string标签的 numpy 数组。处理不是我可以表示为从文件和tf.image操作中读取的东西。

我的问题是关于如何使发电机平行化?我如何让这些生成器中的 N 个在它们自己的线程中运行。

一种想法是使用dataset.mapwithnum_parallel_calls来处理线程;但是地图在张量上运行......另一个想法是创建多个生成器,每个生成器都有自己的prefetch并以某种方式加入它们,但我看不到如何加入 N 个生成器流?

我可以遵循任何规范的例子吗?

4

3 回答 3

29

事实证明,Dataset.map如果我使生成器超轻量级(仅生成元数据),然后将实际的重照明移动到无状态函数中,我可以使用。这样我就可以将繁重的部分与.map使用py_func.

作品; 但感觉有点笨拙......能够添加num_parallel_callsfrom_generator:)

def pure_numpy_and_pil_complex_calculation(metadata, label):
  # some complex pil and numpy work nothing to do with tf
  ...

dataset = tf.data.Dataset.from_generator(lightweight_generator,
                                         output_types=(tf.string,   # metadata
                                                       tf.string))  # label

def wrapped_complex_calulation(metadata, label):
  return tf.py_func(func = pure_numpy_and_pil_complex_calculation,
                    inp = (metadata, label),
                    Tout = (tf.uint8,    # (H,W,3) img
                            tf.string))  # label
dataset = dataset.map(wrapped_complex_calulation,
                      num_parallel_calls=8)

dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()
于 2017-11-03T05:52:41.693 回答
10

我正在from_indexabletf.data.Dataset https://github.com/tensorflow/tensorflow/issues/14448工作

for的优点from_indexable是可以并行化,而python生成器不能并行化。

该函数from_indexable生成一个tf.data.range,将可索引包装在一个通用tf.py_func并调用映射。

对于那些现在想要 a 的人from_indexable,这里是 lib 代码

import tensorflow as tf
import numpy as np

from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import nest

def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args):
            nonlocal output_shapes

            flat_output_types = nest.flatten(output_types)
            flat_values = tf.py_func(
                func, 
                inp=args, 
                Tout=flat_output_types,
                stateful=stateful, name=name
            )
            if output_shapes is not None:
                # I am not sure if this is nessesary
                output_shapes = nest.map_structure_up_to(
                    output_types, tensor_shape.as_shape, output_shapes)
                flattened_shapes = nest.flatten_up_to(output_types, output_shapes)
                for ret_t, shape in zip(flat_values, flattened_shapes):
                    ret_t.set_shape(shape)
            return nest.pack_sequence_as(output_types, flat_values)
        return call
    return decorator

def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    ds = tf.data.Dataset.range(len(iterator))
    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]    
    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)

这里有一个例子(注意:from_indexable有一个 num_parallel_calls argument

class PyDataSet:
    def __len__(self):
        return 20

    def __getitem__(self, item):
        return np.random.normal(size=(item+1, 10))

ds = from_indexable(PyDataSet(), output_types=tf.float64, output_shapes=[None, 10])
it = ds.make_one_shot_iterator()
entry = it.get_next()
with tf.Session() as sess:
    print(sess.run(entry).shape)
    print(sess.run(entry).shape)

2018 年 6 月 10 日更新:由于https://github.com/tensorflow/tensorflow/pull/15121被合并,代码from_indexable简化为:

import tensorflow as tf

def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args, **kwargs):
            return tf.contrib.framework.py_func(
                func=func, 
                args=args, kwargs=kwargs, 
                output_types=output_types, output_shapes=output_shapes, 
                stateful=stateful, name=name
            )
        return call
    return decorator

def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    ds = tf.data.Dataset.range(len(iterator))
    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]    
    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)
于 2017-12-19T10:30:21.363 回答
5

将完成的工作限制在generator最低限度并使用 a 并行化昂贵的处理map是明智的。

或者,您可以使用以下方式“加入”多个生成器parallel_interleave

定义生成器(n):
  # 返回第 n 个生成器函数

定义数据集(n):
  return tf.data.Dataset.from_generator(generator(n))

ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset, cycle_lenght=N))

# 其中 N 是您使用的生成器的数量
于 2017-12-18T18:04:58.783 回答