我正在使用 Spark 作业来生成一个TFRecord
文件,该文件将是我的(单词,计数)对的词汇文件。
我想使用 Dataset API 一次加载整个文件,因为我的词汇文件可能位于 HDFS 上,并且可能被拆分为多个物理文件。也就是说,我发现它非常不直观。到目前为止,这是我的代码:
def parse(example):
parsed = tf.parse_single_example(example, features={
'token': tf.FixedLenFeature([], dtype=tf.string),
'count': tf.FixedLenFeature([], dtype=tf.int64)
})
return parsed['token'], parsed['count']
filenames = tf.gfile.Glob(filenames)
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(parse)
dataset = dataset.batch(MAX_VOCAB_FILE)
iterator = dataset.make_one_shot_iterator()
token, token_count = iterator.get_next()
使用一个巨大的、固定的预先批量大小是我能想到的在shape=(num_entries,)
. 它似乎也运行得很慢。
有没有更好的办法?