The RNN tutorial in the Chainer documentation has incorrect code on this page: https://docs.chainer.org/en/stable/tutorial/recurrentnet.html

```python
def update_bptt(updater):
    loss = 0
    for i in range(35):
        batch = train_iter.__next__()
        x, t = chainer.dataset.concat_examples(batch)
        loss += model(chainer.Variable(x), chainer.Variable(t))

    model.cleargrads()
    loss.backward()
    loss.unchain_backward()  # truncate
    optimizer.update()

updater = training.StandardUpdater(train_iter, optimizer, update_bptt)  # third positional argument binds to `converter`
```

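The third positional parameter of `training.StandardUpdater` is `converter` (default `concat_examples`), not an update function, so `update_bptt` would be used as the batch converter. How do I correctly write BPTT with the Trainer?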
