我开始使用新的数据集 API,我想做的一件事没有在文档中描述(https://www.tensorflow.org/programmers_guide/datasets#training_workflows)
我的数据适合内存,所以我想将它加载到 tensorflow 中以提高训练效率,为此我现在看到了 2 种方法:
一种是直接在图中加载数据,如下所示:
dataset = tf.contrib.data.Dataset.from_tensor_slices((X, Y))
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# loop on epochs
for _ in range(5):
# Initialize an iterator over the training dataset.
sess.run(iterator.initializer)
# loop over all the batch
for _ in range(1000):
s = time.time()
try:
sess.run(next_element)
except tf.errors.OutOfRangeError:
print("Finish epoch")
另一种是将数据加载到占位符中,这样数据就不会保存在图表中:
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.contrib.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# loop on epochs
for _ in range(5):
# Initialize an iterator over the training dataset.
sess.run(iterator.initializer, feed_dict={features_placeholder: X, labels_placeholder: Y})
# loop over all the batch
for _ in range(1000):
s = time.time()
try:
sess.run(next_element)
except tf.errors.OutOfRangeError:
print("Finish epoch")
第二个是我认为最好节省内存,但我不想在每个时期都提供数据。真的是白白损失了性能。
有没有办法用占位符只初始化一次迭代器?
像这样的东西:
sess.run(iterator.initializer, feed_dict={features_placeholder: X, labels_placeholder: Y})
# loop on epochs
for _ in range(5):
# Initialize an iterator over the training dataset.
sess.run(iterator.initializer)
# loop over all the batch
for _ in range(1000):
s = time.time()
try:
sess.run(next_element)
except tf.errors.OutOfRangeError:
print("Finish epoch")
这样我们可以保持第一个解决方案的性能并像第二个解决方案一样节省内存。
笔记:
一种解决方案是使用方法定义 epoch 的数量,
dataset.repeat()
但使用它我们有点松散地跟踪我们在训练中的位置。我想在每个时期(一次遍历所有数据)之后检查损失的演变。