1

我知道LSTM Followed by Mean Pooling有一个类似的主题,但那是关于 Keras 的,我在纯 TensorFlow 中工作。

我有一个 LSTM 网络,其中重复由以下方式处理:

outputs, final_state = tf.nn.dynamic_rnn(cell,
                                         embed,
                                         sequence_length=seq_lengths,
                                         initial_state=initial_state)

我为每个样本传递正确的序列长度(用零填充)。在任何情况下,输出都包含不相关的输出,因为根据序列长度,某些样本比其他样本产生更长的输出。

现在我正在通过以下方法提取最后一个相关输出:

def extract_axis_1(data, ind):
    """
    Get specified elements along the first axis of tensor.
    :param data: Tensorflow tensor that will be subsetted.
    :param ind: Indices to take (one for each element along axis 0 of data).
    :return: Subsetted tensor.
    """

    batch_range = tf.range(tf.shape(data)[0])
    indices = tf.stack([batch_range, ind], axis=1)
    res = tf.reduce_mean(tf.gather_nd(data, indices), axis=0)

sequence_length - 1我作为索引传递的地方。关于最后一个主题,我想选择所有相关输出,然后选择平均池,而不仅仅是最后一个。

现在,我尝试将嵌套列表作为索引传递给extract_axis_1tf.stack不接受。

有什么解决方法吗?

4

1 回答 1

0

您可以利用函数的weight参数tf.contrib.seq2seq.sequence_loss

从文档中:

weights: 形状为 [ batch_size, sequence_length] 和 dtype的张量float。权重构成序列中每个预测的权重。当weights用作掩码时,将所有有效时间步设置为 1,将所有填充时间步设置为 0,例如由 . 返回的掩码tf.sequence_mask

您需要计算一个二进制掩码来区分有效输出和无效输出。然后你可以把这个掩码提供给weights损失函数的参数(可能你会想要使用这样的损失);该函数在计算损失时不会考虑权重为 0 的输出。

如果您不能/不需要使用序列丢失,您可以手动执行完全相同的操作。您计算一个二进制掩码,然后将您的输出乘以该掩码,并将这些作为输入提供给您的全连接层。

于 2017-09-04T07:58:42.570 回答