是否可以从文件中获取记录总数.tfrecords
?与此相关,人们通常如何跟踪训练模型时经过的 epoch 数?虽然我们可以指定batch_size
and num_of_epochs
,但我不确定是否可以直接获得诸如current epoch
每个 epoch 的批次数等值 - 以便我可以更好地控制训练的进度。目前,我只是使用肮脏的技巧来计算这个,因为我事先知道我的 .tfrecords 文件中有多少条记录以及我的小批量的大小。感谢任何帮助..
问问题
20472 次
5 回答
33
要计算记录数,您应该能够使用tf.python_io.tf_record_iterator
.
c = 0
for fn in tf_records_filenames:
for record in tf.python_io.tf_record_iterator(fn):
c += 1
为了跟踪模型训练,张量板就派上用场了。
于 2016-11-07T20:10:43.607 回答
25
不,这是不可能的。TFRecord不存储有关存储在其中的数据的任何元数据。这个文件
表示一系列(二进制)字符串。该格式不是随机访问,因此适用于流式传输大量数据,但不适用于需要快速分片或其他非顺序访问的情况。
如果需要,您可以手动存储此元数据或使用record_iterator来获取数字(您需要遍历您拥有的所有记录:
sum(1 for _ in tf.python_io.tf_record_iterator(file_name))
如果你想知道当前的 epoch,你可以从 tensorboard 或通过打印循环中的数字来做到这一点。
于 2017-07-03T22:07:25.350 回答
10
由于tf.io.tf_record_iterator被弃用,萨尔瓦多·达利的伟大答案现在应该阅读
tf.enable_eager_execution()
sum(1 for _ in tf.data.TFRecordDataset(file_name))
于 2019-10-16T12:31:24.733 回答
6
根据tf_record_iterator上的弃用警告,我们还可以使用急切执行来计算记录。
#!/usr/bin/env python
from __future__ import print_function
import tensorflow as tf
import sys
assert len(sys.argv) == 2, \
"USAGE: {} <file_glob>".format(sys.argv[0])
tf.enable_eager_execution()
input_pattern = sys.argv[1]
# Expand glob if there is one
input_files = tf.io.gfile.glob(input_pattern)
# Create the dataset
data_set = tf.data.TFRecordDataset(input_files)
# Count the records
records_n = sum(1 for record in data_set)
print("records_n = {}".format(records_n))
于 2019-07-25T18:22:06.270 回答
0
由于 tf.enable_eager_execution() 不再有效,请使用:
tf.compat.v1.enable_eager_execution
sum(1 for _ in tf.data.TFRecordDataset(FILENAMES))
于 2021-05-23T20:09:40.063 回答