我知道LSTM Followed by Mean Pooling有一个类似的主题,但那是关于 Keras 的,我在纯 TensorFlow 中工作。
我有一个 LSTM 网络,其中重复由以下方式处理:
outputs, final_state = tf.nn.dynamic_rnn(cell,
embed,
sequence_length=seq_lengths,
initial_state=initial_state)
我为每个样本传递正确的序列长度(用零填充)。在任何情况下,输出都包含不相关的输出,因为根据序列长度,某些样本比其他样本产生更长的输出。
现在我正在通过以下方法提取最后一个相关输出:
def extract_axis_1(data, ind):
"""
Get specified elements along the first axis of tensor.
:param data: Tensorflow tensor that will be subsetted.
:param ind: Indices to take (one for each element along axis 0 of data).
:return: Subsetted tensor.
"""
batch_range = tf.range(tf.shape(data)[0])
indices = tf.stack([batch_range, ind], axis=1)
res = tf.reduce_mean(tf.gather_nd(data, indices), axis=0)
sequence_length - 1
我作为索引传递的地方。关于最后一个主题,我想选择所有相关输出,然后选择平均池,而不仅仅是最后一个。
现在,我尝试将嵌套列表作为索引传递给extract_axis_1
但tf.stack
不接受。
有什么解决方法吗?