0

如何使用 Chainer 中的 Standard Updater 类大幅增加小批量的数量?

在 PyTorch 的情况下,我可以大幅增加 mini-batch 的数量。

  • 每次执行 loss.backward() 。
  • 每 3 次执行 optimizer.step() / optimizer.zero_grad() 一次。这有效地大幅增加了 mini-batch 的数量。

问题 1. Chainer 的情况下,是否可以大幅增加 mini-batch 的数量?

  • 每次执行 loss.backward() 。
  • 每 3 次执行一次 net.cleargrads()/optimizer.update()。这可以大幅增加小批量的数量吗?

问题 2. 事实上,我使用的是 StandardUpdater 类。是否可以使用任何超参数大幅增加小批量的数量?或者我应该让我的类继承自 StandardUpdater 类并更改上面的实现?

如果问题已经被问到,我很抱歉。

我希望任何建议。

4

1 回答 1

0

(问题似乎很老了,但我偶然发现它并想分享我对这个问题的解决方案)

你基本上会像在 PyTorch 中那样做。不幸的是,StandardUpdater 既没有支持它的超参数,也没有“小批量更新”的实现。但这是我的实现,我是如何做到的(基本上就像您在问题中提到的那样:从 StandardUpdater 继承并重新实现 update_core 方法):

from chainer.training import StandardUpdater
from chainer.dataset import convert

class MiniBatchUpdater(StandardUpdater):
    """
        The iterator outputs batches in mini-batch sizes. This updater
        cummulates the gradients of these mini-batches until the
        update_size is reached. Then a parameter update is performed
    """
    def __init__(self, update_size=32, *args, **kwargs):
        super(MiniBatchUpdater, self).__init__(*args, **kwargs)
        self.update_size = update_size
        self.iteration_counter = 0

    def update_core(self):
        optimizer = self.get_optimizer('main')
        loss_func = self.loss_func or optimizer.target
        it = self.get_iterator('main')

        batch = it.next()
        data = convert._call_converter(self.converter, batch, self.device)

        use_cleargrads = getattr(optimizer, '_use_cleargrads', True)
        if use_cleargrads and self.iteration_counter == 0:
            optimizer.target.cleargrads()

        self.iteration_counter += it.batch_size
        loss = loss_func(*data)
        loss.backward()

        if self.iteration_counter >= self.update_size:
            self.iteration_counter = 0
            optimizer.update()

实现已经很老了(我认为是 chainer 4 或 5),但我也使用 chainer 7.8 为我工作。可以更新一些行以匹配 update_core 方法的较新实现,但正如我所说,它适用于我。希望它有所帮助;)

于 2021-08-10T07:22:25.683 回答