我已将数据集划分为 10 个 tfrecords 文件,我想从每个文件中读取 100 个数据点,以创建一批 10 个 100 个数据点的序列。我使用以下函数来做到这一点。来自 tfrecords 的数据加载时间开始缓慢,然后达到大约 0.65 秒,在 100-200 sess.run 调用后增加到大约 10 秒。您能否指出任何可能有助于减少阅读时间的错误或建议?此外,我提到的行为有时会变得更加不稳定。
def get_data(mini_batch_size):
data = []
for i in range(mini_batch_size):
filename_queue = tf.train.string_input_producer([data_path + 'Features' + str(i) + '.tfrecords'])
reader = tf.TFRecordReader()
_, serialized_example = reader.read_up_to(filename_queue,step_size)
features = tf.parse_example(serialized_example,features={'feature_raw': tf.VarLenFeature(dtype=tf.float32)})
feature = features['feature_raw'].values
feature = tf.reshape(feature,[step_size, ConvLSTM.H, ConvLSTM.W, ConvLSTM.Di])
data.append(feature)
return tf.stack(data)
即使我按如下方式从单个文件中提取,我也观察到相同的行为。此外,增加 num_threads 也无济于事。
with tf.device('/cpu:0'):
filename_queue = tf.train.string_input_producer(['./Data/TFRecords/Features' + str(i) + '.tfrecords'])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
batch_serialized_example = tf.train.batch([serialized_example], batch_size=100, num_threads=1, capacity=100)
features = tf.parse_example(batch_serialized_example,features={'feature_raw': tf.VarLenFeature(dtype=tf.float32)})
feature = features['feature_raw'].values
data.append(feature)
data = tf.stack(data)
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1,inter_op_parallelism_threads=1,allow_soft_placement=True))
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(1000):
t = time.time()
D = sess.run(data)
print(time.time()-t)