(The question seems quite old, but I stumbled across it and want to share my solution to this problem.)
You basically do it the same way as in PyTorch. Unfortunately, StandardUpdater has neither a hyperparameter for this nor a built-in "accumulate over mini-batches" implementation. But here is how I did it (essentially what you mentioned in your question: inherit from StandardUpdater and reimplement the update_core method):
from chainer.training import StandardUpdater
from chainer.dataset import convert


class MiniBatchUpdater(StandardUpdater):
    """Updater that accumulates gradients over several mini-batches.

    The iterator outputs batches of its own (smaller) batch size. This
    updater accumulates the gradients of these mini-batches until
    ``update_size`` samples have been seen; only then is a parameter
    update performed.
    """

    def __init__(self, update_size=32, *args, **kwargs):
        super(MiniBatchUpdater, self).__init__(*args, **kwargs)
        self.update_size = update_size
        self.iteration_counter = 0

    def update_core(self):
        optimizer = self.get_optimizer('main')
        loss_func = self.loss_func or optimizer.target
        it = self.get_iterator('main')

        # Fetch the next mini-batch and convert it to model inputs.
        batch = it.next()
        data = convert._call_converter(self.converter, batch, self.device)

        # Clear the gradients only at the start of an accumulation cycle,
        # so subsequent backward passes add to the stored gradients.
        use_cleargrads = getattr(optimizer, '_use_cleargrads', True)
        if use_cleargrads and self.iteration_counter == 0:
            optimizer.target.cleargrads()

        # Accumulate the gradients of this mini-batch.
        self.iteration_counter += it.batch_size
        loss = loss_func(*data)
        loss.backward()

        # Once enough samples have been accumulated, apply the update.
        if self.iteration_counter >= self.update_size:
            self.iteration_counter = 0
            optimizer.update()
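For completeness, here is a minimal usage sketch of how the class could be wired into a Trainer. The model, dataset, batch sizes and device below are just placeholders I picked for illustration; with an iterator batch size of 8 and update_size=32, a parameter update happens every 4 mini-batches:

import chainer.links as L
from chainer import optimizers, training, iterators
from chainer.datasets import get_mnist

# Hypothetical model: a simple linear classifier on MNIST.
model = L.Classifier(L.Linear(784, 10))
optimizer = optimizers.Adam()
optimizer.setup(model)

train, _ = get_mnist()
train_iter = iterators.SerialIterator(train, batch_size=8)

# Gradients of 4 mini-batches (4 * 8 = 32 samples) are accumulated
# before each optimizer.update() call.
updater = MiniBatchUpdater(update_size=32, iterator=train_iter,
                           optimizer=optimizer, device=-1)
trainer = training.Trainer(updater, (10, 'epoch'), out='result')
trainer.run()

Note that loss.backward() sums the gradients of the individual mini-batches, so the effective gradient is roughly update_size / batch_size times larger than with a single large batch; depending on the optimizer you may want to scale the loss (or the learning rate) accordingly.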
The implementation is quite old (Chainer 4 or 5, I think), but it also works for me with Chainer 7.8. Some lines could be updated to match the newer implementation of the update_core method, but as I said, it works for me. Hope it helps ;)