0

我遇到了几个使用 RNN 对 MNIST 数字进行分类的示例,使用 sequence_length=1 初始化隐藏状态的原因是什么?如果您正在对视频帧预测进行 1 步预测,您将如何初始化它?

def init_hidden(self, x, device=None): # input 4D tensor: (batch size, channels, width, height)
    # initialize the hidden and cell state to zero
    # vectors:(number of layer, sequence length, number of hidden nodes)
    if (self.bidirectional):
        h0 = torch.zeros(2*self.n_layers, 1, self.n_hidden)
    else:
        h0 = torch.zeros(self.n_layers,  1, self.n_hidden)

    if device is not None:
        h0 = h0.to(device)
    self.hidden = h0

输入通常表示为

inputs = inputs.view(batch_size*image_height, 1, image_width)

在上面的示例中,图像是按列传递的吗?是否有另一种方法来表示 RNN 中的输入图像?它与如何初始化隐藏状态有什么关系?

4

1 回答 1

0

初始化隐藏状态时,第二维其实不是sequence-length,而是batch size:

hidden = torch.zeros(layers, batch_size, hidden_nodes)

对于 MNIST rnn,我会说输入形状是 28x1(一行的形状),序列长度也是 28(有 28 行)。

input_size = 28
hidden_nodes = 128 # for example
layers = 2 # for example
dropout = 0.35

rnn = nn.RNN(input_size=input_size, hidden_size=hidden_nodes, num_layers=layers, dropout=dropout, batch_first=True)

现在让我们初始化隐藏状态:

hidden = torch.zeros(layers, batch_size, hidden_nodes)

您不必告诉隐藏状态序列有多长,也不必告诉序列的元素有多长。隐藏层应该有多大。

因此,您可以将 mnist 的序列长度不能为 1,它必须为 28,因为有 28 行。序列大小为 1 的 RNN 没有任何意义,因为序列只有在具有超过 1 个元素时才是序列。

编辑以回答评论中的问题:

会的(batch_size, 28, 28)。就像在没有通道维度的情况下将图像传递给 cnn 一样。前 28 位代表序列长度。第二个 28 表示一个序列有多长。也许另一个例子更清楚:如果你有一个 RNN,它(无论出于何种原因)4 个字母作为输入,并且每个字母都是单热编码的(a例如,字母将是一个长度为 26 的向量,长度为字母表,其中每个元素为零,但第一个元素为 1)输入维度如下所示:(batch_size, 4, 26),batch_size,序列长度为 4(4 个字母),序列中的每个元素/字母的长度为 28(一个-热编码字母表)。

于 2020-10-31T10:16:36.750 回答