5

我正在对具有不同长度的批量序列训练 LSTM 单元。tf.nn.rnn有一个非常方便的参数,sequence_length但是调用之后,我不知道如何选择批处理中每个项目的最后一个时间步对应的输出行。

我的代码基本上如下:

lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)

lstm_outputs是每个时间步的 LSTM 输出列表。但是,我的批次中的每个项目都有不同的长度,因此我想创建一个张量,其中包含对我的批次中的每个项目有效的最后一个 LSTM 输出。

如果我可以使用 numpy 索引,我会做这样的事情:

all_outputs = tf.pack(lstm_outputs)
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :]

但事实证明,当时开始 tensorflow 不支持它(我知道功能请求)。

那么,我怎样才能得到这些值呢?

4

3 回答 3

5

danijar 在我在问题中链接的功能请求页面上发布了一个更可接受的解决方法。它不需要评估张量,这是一个很大的优势。

我让它与 tensorflow 0.8 一起工作。这是代码:

def extract_last_relevant(outputs, length):
    """
    Args:
        outputs: [Tensor(batch_size, output_neurons)]: A list containing the output
            activations of each in the batch for each time step as returned by
            tensorflow.models.rnn.rnn.
        length: Tensor(batch_size): The used sequence length of each example in the
            batch with all later time steps being zeros. Should be of type tf.int32.

    Returns:
        Tensor(batch_size, output_neurons): The last relevant output activation for
            each example in the batch.
    """
    output = tf.transpose(tf.pack(outputs), perm=[1, 0, 2])
    # Query shape.
    batch_size = tf.shape(output)[0]
    max_length = int(output.get_shape()[1])
    num_neurons = int(output.get_shape()[2])
    # Index into flattened array as a workaround.
    index = tf.range(0, batch_size) * max_length + (length - 1)
    flat = tf.reshape(output, [-1, num_neurons])
    relevant = tf.gather(flat, index)
    return relevant
于 2016-05-03T10:13:49.707 回答
2

这不是最好的解决方案,但您可以评估您的输出,然后使用 numpy 索引来获取结果并从中创建一个张量变量?在 tensorflow 获得此功能之前,它可能会作为一个权宜之计。例如

all_outputs = session.run(lstm_outputs, feed_dict={'your inputs'})
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :]
use_this_as_an_input_to_new_tensorflow_op = tf.constant(last_outputs)
于 2016-03-07T14:17:17.710 回答
1

如果您只对最后一个有效输出感兴趣,您可以通过返回的状态来检索它,tf.nn.rnn()因为它始终是一个元组 (c, h),其中 c 是最后一个状态,h 是最后一个输出。当状态为 a 时,LSTMStateTuple您可以使用以下代码段(在 tensorflow 0.12 中工作):

lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)
last_output = state[1]
于 2017-01-15T13:28:31.497 回答