1

我正在使用 Spark 作业来生成一个TFRecord文件,该文件将是我的(单词,计数)对的词汇文件。

我想使用 Dataset API 一次加载整个文件,因为我的词汇文件可能位于 HDFS 上,并且可能被拆分为多个物理文件。也就是说,我发现它非常不直观。到目前为止,这是我的代码:

def parse(example):
    parsed = tf.parse_single_example(example, features={
        'token': tf.FixedLenFeature([], dtype=tf.string),
        'count': tf.FixedLenFeature([], dtype=tf.int64)
    })
    return parsed['token'], parsed['count']

filenames = tf.gfile.Glob(filenames)
dataset   = tf.data.TFRecordDataset(filenames)
dataset   = dataset.map(parse)
dataset   = dataset.batch(MAX_VOCAB_FILE)
iterator  = dataset.make_one_shot_iterator()

token, token_count = iterator.get_next()

使用一个巨大的、固定的预先批量大小是我能想到的在shape=(num_entries,). 它似乎也运行得很慢。

有没有更好的办法?

4

1 回答 1

0

我不得不在某一时刻做类似的事情。这并不直观,但在导入数据的 TF 文档中隐藏了一个答案。

为了使我的解决方案更易于理解,我假设您的代码包含在一个名为的方法中get_batch

def get_batch(filenames, batch_size):
    # your code
    return token, token_count

参数在您的示例中batch_size替换的位置。MAX_VOCAB_FILE如果要将所有连续(token, token_count)对打印到标准输出,一次一行,您可以这样做:

with tf.Session() as sess:
    example = get_batch(filenames, 1)
    while True:
        print(sess.run(example))

然后迭代器耗尽,循环以OutOfRangeError. 发出的每条记录都是一对重铸为 numpy 对象的张量。

希望这是足够的细节来做一些有用的事情。

于 2018-01-13T19:06:52.140 回答