1

我想用 tf.estimator.Estimator 训练我的模式并通过 Dataset API 加载我的数据。因为我的数据,例如“mnist”,是一个数组(张量),所以我尝试用“tf.data”加载它。 Dataset.from_tensor_slices'。但我不知道如何在“input_fn”中初始化“make_initializable_iterator”。

如果我可以使用“make_one_shot_iterator”成功训练,但它在训练前加载缓慢。而《<a href="https://medium.com/onfido-tech/higher-level-apis-in-tensorflow-67bfb602e6c0" rel="nofollow noreferrer">Higher-Level APIs in TensorFlow》就是一个很好的例子'input_fn' 中的 'make_initializable_iterator',但它需要从 'input_fn' 返回一个 'iterator_initializer_hook' 给其他函数。我想知道还有其他更好或更优雅的方式吗?

    def input_fn():

    mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
    images = mnist_data.train.images.reshape([-1, 28, 28, 1])
    labels = np.asarray(mnist_data.train.labels, dtype=np.int64)

    # Build dataset iterator
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.repeat(None)  # Infinite iterations
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(100)
    iterator = dataset.make_one_shot_iterator()
    next_example = iterator.get_next()
    # Set runhook to initialize iterator

    return next_example
4

2 回答 2

7

在 TensorFlow 1.5 及更高版本中,tf.estimator.Estimator当您tf.data.Datasetinput_fn. 这使您可以编写以下代码,而不必担心初始化或挂钩:

def input_fn():
    mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
    images = mnist_data.train.images.reshape([-1, 28, 28, 1])
    labels = np.asarray(mnist_data.train.labels, dtype=np.int64)

    # Build dataset.
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.repeat(None)  # Infinite iterations
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(100)
    return dataset
于 2018-02-05T17:18:35.363 回答
0

在您的代码中,添加以下内容:

      self.hooks.append(utils_hooks.DatasetHook(iter))

在 run_loop.py 中,在调用你的 fn 之前,添加这个

 for hook in dataset_hooks:
        sess.run(hook.iterator().initializer)

那么,应该没问题。

于 2019-07-22T08:29:49.990 回答