0

我目前正在研究的神经网络接受稀疏张量作为输入。我正在从 TFRecord 读取我的数据,如下所示:

_, examples = tf.TFRecordReader(options=options).read_up_to(
    filename_queue, num_records=batch_size)

features = tf.parse_example(examples, features={
          'input_feat': tf.SparseFeature(index_key='input_feat_idx',
                                         value_key='input_feat_values',
                                         dtype=tf.int64,
                                         size=SIZE_FEATURE)})

它的工作原理就像一个魅力,但我正在查看tf.data对于很多任务来说看起来更方便的 API,我不确定如何像使用and那样读取tf.SparseTensor对象。任何想法?tf.RecordReadertf.parse_example()

4

1 回答 1

4

tf.SparseTensorTensorFlow 1.5 将在核心转换中添加原生支持。(如果您是当前可用pip install tf-nightly的,或者从 TensorFlow 的主分支上的源代码构建。)这意味着您可以编写如下管道:

# Create a dataset of string records from the input files.
dataset = tf.data.TFRecordReader(filenames)

# Convert each string record into a `tf.SparseTensor` representing a single example.
dataset = dataset.map(lambda record: tf.parse_single_example(
    record, features={'input_feat': tf.SparseFeature(index_key='input_feat_idx',
                                                     value_key='input_feat_values',
                                                     dtype=tf.int64,
                                                     size=SIZE_FEATURE)})

# Stack together up to `batch_size` consecutive elements into a `tf.SparseTensor`
# representing a batch of examples.
dataset = dataset.batch(batch_size)

# Create an iterator to access the elements of `dataset` sequentially.
iterator = dataset.make_one_shot_iterator()

# `next_element` is a `tf.SparseTensor`.
next_element = iterator.get_next()
于 2017-12-06T15:27:02.257 回答