1

我有一个很长的时间序列,我想输入 LSTM 以进行每帧分类。

我的数据是按帧标记的,我知道一些罕见的事件发生后会严重影响分类。

因此,我必须输入整个序列才能获得有意义的预测。

众所周知,仅将非常长的序列输入 LSTM 是次优的,因为梯度会像正常的 RNN 一样消失或爆炸。


我想使用一种简单的技术将序列切割成更短(比如 100 长)的序列,并在每个序列上运行 LSTM,然后将最终的 LSTM 隐藏和单元状态作为下一个前向传递的开始隐藏和单元状态传递.

是我发现的一个例子。在那里它被称为“通过时间截断的反向传播”。我无法为我做同样的工作。


我在 Pytorch 闪电中的尝试(去掉了不相关的部分):

def __init__(self, config, n_classes, datamodule):
    ...
    self._criterion = nn.CrossEntropyLoss(
        reduction='mean',
    )

    num_layers = 1
    hidden_size = 50
    batch_size=1

    self._lstm1 = nn.LSTM(input_size=len(self._in_features), hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
    self._log_probs = nn.Linear(hidden_size, self._n_predicted_classes)
    self._last_h_n = torch.zeros((num_layers, batch_size, hidden_size), device='cuda', dtype=torch.double, requires_grad=False)
    self._last_c_n = torch.zeros((num_layers, batch_size, hidden_size), device='cuda', dtype=torch.double, requires_grad=False)

def training_step(self, batch, batch_index):
    orig_batch, label_batch = batch
    n_labels_in_batch = np.prod(label_batch.shape)
    lstm_out, (self._last_h_n, self._last_c_n) = self._lstm1(orig_batch, (self._last_h_n, self._last_c_n))
    log_probs = self._log_probs(lstm_out)
    loss = self._criterion(log_probs.view(n_labels_in_batch, -1), label_batch.view(n_labels_in_batch))

    return loss

运行此代码会出现以下错误:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

如果我添加也会发生同样的情况

def on_after_backward(self) -> None:
    self._last_h_n.detach()
    self._last_c_n.detach()

如果我使用该错误不会发生

lstm_out, (self._last_h_n, self._last_c_n) = self._lstm1(orig_batch,)

但显然这是没有用的,因为当前帧批次的输出不会转发到下一个。


是什么导致了这个错误?我认为分离输出应该足够了h_nc_n

如何将前一个帧批次的输出传递给下一个帧,并让火炬分别反向传播每个帧批次?

4

1 回答 1

1

显然,我错过了_以下内容detach()

使用

def on_after_backward(self) -> None:
    self._last_h_n.detach_()
    self._last_c_n.detach_()

作品。


问题是self._last_h_n.detach()没有更新对由 detach() 分配的新内存的引用,因此该图仍在取消引用 backprop 所经过的旧变量。 参考答案解决了这个问题H = H.detach()

更清洁(并且可能更快)是self._last_h_n.detach_()执行操作的地方。

于 2021-04-01T11:40:14.037 回答