23

我正在使用Tensorflow v1.3 中的 Dataset API。这很棒。可以使用此处描述的函数映射数据集。我很想知道如何传递具有附加参数的函数,例如arg1

def _parse_function(example_proto, arg1):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int32, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]

当然,

dataset = dataset.map(_parse_function)

行不通,因为没有办法通过arg1

4

3 回答 3

39

这是一个使用 lambda 表达式来包装我们想要传递参数的函数的示例:

import tensorflow as tf
def fun(x, arg):
    return x * arg

my_arg = tf.constant(2, dtype=tf.int64)
ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: fun(x, my_arg))

在上面,提供给的函数的签名map必须与我们数据集的内容相匹配。所以我们必须编写我们的 lambda 表达式来匹配它。这里很简单,因为数据集中只包含一个元素,x即包含 0 到 4 范围内的元素。

如有必要,您可以从数据集外部传入任意数量的外部参数:ds = ds.map(lambda x: my_other_fun(x, arg1, arg2, arg3)等等。

为了验证上述方法是否有效,我们可以观察到映射确实将每个数据集元素乘以 2:

iterator = ds.make_initializable_iterator()
next_x = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)

    while True:
      try:
        print(sess.run(next_x))
      except tf.errors.OutOfRangeError:
        break

输出:

0
2
4
6
8
于 2018-02-01T19:40:08.630 回答
7

您也可以使用Partial函数来包装您的参数:

def _parse_function(arg1, example_proto):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int32, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]

更改函数的参数顺序以适应偏向性,然后您可以使用如下参数值包装函数:

from functools import partial

arg1 = ...
dataset = dataset.map(partial(_parse_function, arg1))
于 2019-10-22T13:43:17.580 回答
1

另一种解决方案是使用类包装器。在下面的代码中,我将参数shape传递给了 parse 函数。

class MyDataSets:

    def __init__(self, shape):
        self.shape = shape

    def parse_sample(self.sample):
        features = { ... }
        f = tf.parse_example([example], features=features)

        image_raw = tf.decode_raw(f['image_raw'], tf.uint8)
        image = image.reshape(image_raw, self.shape)

        label = tf.cast(f['label'], tf.int32)

        return image, label

    def init(self):
        ds = tf.data.TFRecordDataSets(...)
        ds = ds.map(self.parse_sample)
        ...
        return ds.make_initializable_iterator()
于 2019-03-26T00:43:34.227 回答