1

问题

在 Tensorflow 中,我经常在第一个训练阶段遇到 OOM 错误。然而,网络的庞大性质导致第一个 epoch 需要大约一个小时,对于快速测试新的超参数来说非常长。

理想情况下,我希望能够对迭代器进行排序,这样我就可以get_next()在最大的批次上运行一次。

我怎样才能做到这一点?或者也许有更好的方法来实现早期失败?

迭代器的格式为:(source, tgt_in, tgt_out, key_weights, source_len, target_len)我希望按目标长度排序。它在返回之前被填充和批处理。

数据集是一个句子列表,具有相似的长度。我想在迭代器中找到最大的批次并只运行它。

一些代码

如果初始化程序没有每次都对迭代器进行洗牌,则下面的代码将起作用,从而破坏获得的有关最大批次位置的信息。我不太确定如何修改它——一旦使用 读取批次的长度get_next(),它就已经被“弹出”并且不能再用作模型的输入。

def verify_hparams():
    train_sess.run(train_model.iterator.initializer)
    max_index = -1
    max_len = 0
    for batch in itertools.count():
        try:
            batch_len = np.amax(train_sess.run(train_model.iterator.get_next()[-1]))
            if batch_len > max_len:
                max_len = batch_len
                max_index = batch

        except tf.errors.OutOfRangeError:
            num_batches = batch + 1
            break

    for batch in range(-1, num_batches-1):
        try:
            if batch is max_index:
                _, _ = loaded_train_model.train(train_sess)
            else:
                train_sess.run(train_model.iterator.get_next())

        except tf.errors.OutOfRangeError:
            break

    return num_batches
4

1 回答 1

1

您需要的是“窥视”操作。大多数语言都有迭代器,可以让您查看是否有更多数据(例如iterator.hasNext())。但是您要求的功能本质上类似于iterator.sizeOfNext(). 据我所知,张量流迭代器没有这样的功能

此外,不太可能添加此类功能,因为我可以想象有些生成器无法提供此类功能,因此添加此功能会破坏向后兼容性。

于 2018-02-02T12:06:03.090 回答