我正在构建一个用于语言识别的全状态 LSTM。在有状态的情况下,我可以用较小的文件训练网络,一个新的批次就像讨论中的下一个句子。然而,为了使网络得到适当的训练,我需要在一些批次之间重置 LSTM 的隐藏状态。
我正在使用一个变量来存储 LSTM 的 hidden_state 以提高性能:
with tf.variable_scope('Hidden_state'):
hidden_state = tf.get_variable("hidden_state", [self.num_layers, 2, self.batch_size, self.hidden_size],
tf.float32, initializer=tf.constant_initializer(0.0), trainable=False)
# Arrange it to a tuple of LSTMStateTuple as needed
l = tf.unstack(hidden_state, axis=0)
rnn_tuple_state = tuple([tf.contrib.rnn.LSTMStateTuple(l[idx][0], l[idx][1])
for idx in range(self.num_layers)])
# Build the RNN
with tf.name_scope('LSTM'):
rnn_output, _ = tf.nn.dynamic_rnn(cell, rnn_inputs, sequence_length=input_seq_lengths,
initial_state=rnn_tuple_state, time_major=True)
现在我对如何重置隐藏状态感到困惑。我尝试了两种解决方案,但它不起作用:
第一个解决方案
使用以下命令重置“hidden_state”变量:
rnn_state_zero_op = hidden_state.assign(tf.zeros_like(hidden_state))
它确实有效,我认为这是因为在运行 rnn_state_zero_op 操作后,unstack 和 tuple 构造没有“重新播放”到图中。
第二种解决方案
在Tensorflow 中针对 RNN 的 LSTMStateTuple vs cell.zero_state()之后,我尝试使用以下命令重置单元状态:
rnn_state_zero_op = cell.zero_state(self.batch_size, tf.float32)
它似乎也不起作用。
问题
我想到了另一个解决方案,但它充其量只是猜测:我没有保留 tf.nn.dynamic_rnn 返回的状态,我已经想到了,但是我得到了一个元组,但我找不到构建一个op 重置元组。
在这一点上,我不得不承认我不太了解 tensorflow 的内部工作,以及是否有可能做我想做的事情。有正确的方法吗?
谢谢 !