1

从 tfrecords 文件导入数据时出现问题。tfrecords 中的每个样本都包含一个长度为 100 的特征向量和一个长度为 13 的 one-hot 标签向量。我使用下面的代码从 tfrecords 导入数据,参考官方指南https://www.tensorflow.org/程序员指南/数据集

def read_data(examples):
    features = {"features": tf.FixedLenFeature([seq_len], tf.int64),
               "label": tf.FixedLenFeature([category], tf.int64)}
    parsed_features = tf.parse_single_example(examples, features)
    return parsed_features['features'], parsed_features['label']

# get next batch of data and label
def next_batch(filename, batch_size):
    data = tf.data.TFRecordDataset(filename)
    data = data.map(read_data)
    data = data.batch(batch_size)
    iterator = data.make_one_shot_iterator()
    next_data, next_label = iterator.get_next()
    return next_data, next_label

with tf.Session() as sess:
    filetrain = 'train.tfrecords'
    next_data, next_label = next_batch(filetrain, num_example_train)
    sess.run(tf.global_variables_initializer())

    data = sess.run(next_data)
    label = sess.run(next_label)

问题是标签的顺序在批处理后变得错误。如果我删除代码“data = data.batch”,一切正常。

我认为一个可能的原因是特征和标签是独立批处理的。所以我尝试在批处理后解析示例,但收到错误“输入序列化必须是标量”。如果您知道如何处理此问题,请帮助我,非常感谢!

4

1 回答 1

1

我确定这是重复的,但我找不到其他问题,所以我会在这里回答。

sess.run()您的问题是两次调用数据和标签。每当您调用sess.run时,都会评估您的图形(即,提取一个新批次并在图形中运行,直到您作为第一个参数传递给的列表中的所有run张量值都已知)。

这样做,您的datalabel指的是两个不同的批次(因此它们看起来是错误的)。

您需要通过以下方式在同一个电话中获得它们:

data, label = sess.run([next_data, next_label])
于 2017-12-18T10:58:14.860 回答