我正在对具有不同长度的批量序列训练 LSTM 单元。tf.nn.rnn
有一个非常方便的参数,sequence_length
但是调用之后,我不知道如何选择批处理中每个项目的最后一个时间步对应的输出行。
我的代码基本上如下:
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)
lstm_outputs
是每个时间步的 LSTM 输出列表。但是,我的批次中的每个项目都有不同的长度,因此我想创建一个张量,其中包含对我的批次中的每个项目有效的最后一个 LSTM 输出。
如果我可以使用 numpy 索引,我会做这样的事情:
all_outputs = tf.pack(lstm_outputs)
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :]
但事实证明,当时开始 tensorflow 不支持它(我知道功能请求)。
那么,我怎样才能得到这些值呢?