问题
在 Tensorflow 中,我经常在第一个训练阶段遇到 OOM 错误。然而,网络的庞大性质导致第一个 epoch 需要大约一个小时,对于快速测试新的超参数来说非常长。
理想情况下,我希望能够对迭代器进行排序,这样我就可以get_next()
在最大的批次上运行一次。
我怎样才能做到这一点?或者也许有更好的方法来实现早期失败?
迭代器的格式为:(source, tgt_in, tgt_out, key_weights, source_len, target_len)
我希望按目标长度排序。它在返回之前被填充和批处理。
数据集是一个句子列表,具有相似的长度。我想在迭代器中找到最大的批次并只运行它。
一些代码
如果初始化程序没有每次都对迭代器进行洗牌,则下面的代码将起作用,从而破坏获得的有关最大批次位置的信息。我不太确定如何修改它——一旦使用 读取批次的长度get_next()
,它就已经被“弹出”并且不能再用作模型的输入。
def verify_hparams():
train_sess.run(train_model.iterator.initializer)
max_index = -1
max_len = 0
for batch in itertools.count():
try:
batch_len = np.amax(train_sess.run(train_model.iterator.get_next()[-1]))
if batch_len > max_len:
max_len = batch_len
max_index = batch
except tf.errors.OutOfRangeError:
num_batches = batch + 1
break
for batch in range(-1, num_batches-1):
try:
if batch is max_index:
_, _ = loaded_train_model.train(train_sess)
else:
train_sess.run(train_model.iterator.get_next())
except tf.errors.OutOfRangeError:
break
return num_batches