Use the initial_state parameter of tf.nn.dynamic_rnn:
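initial_state: (optional) An initial state for the RNN. If cell.state_size is an integer, this must be a Tensor of appropriate type and shape [batch_size, cell.state_size]. If cell.state_size is a tuple, this should be a tuple of tensors having shapes [batch_size, s] for s in cell.state_size.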
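To make the shape rule concrete, here is a minimal NumPy sketch. It does not call TensorFlow; make_zero_state is a hypothetical helper written only for illustration, and it handles just a flat integer or a flat tuple of integers, not nested state structures.

```python
import numpy as np

def make_zero_state(batch_size, state_size, dtype=np.float32):
    """Hypothetical helper: build a zero state that follows the documented shape rule."""
    if isinstance(state_size, int):
        # Integer state_size: a single array of shape [batch_size, state_size]
        return np.zeros((batch_size, state_size), dtype=dtype)
    # Tuple state_size: one array of shape [batch_size, s] per entry s
    return tuple(np.zeros((batch_size, s), dtype=dtype) for s in state_size)

gru_like = make_zero_state(4, 8)        # integer state_size, as with a GRU cell
lstm_like = make_zero_state(4, (8, 8))  # tuple state_size, as with an LSTM-style (c, h) state

print(gru_like.shape)                 # (4, 8)
print([s.shape for s in lstm_like])   # [(4, 8), (4, 8)]
```

In real TensorFlow 1.x code, cell.zero_state(batch_size, dtype) builds this structure for you, so a helper like this is only useful for understanding what shapes to expect.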
An example adapted from the documentation:
import tensorflow as tf  # TensorFlow 1.x API

# create a GRUCell
cell = tf.nn.rnn_cell.GRUCell(cell_size)

# defining initial state
initial_state = cell.zero_state(batch_size, dtype=tf.float32)

# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)
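Also note that although initial_state is not a placeholder, you can still feed a value to it. So if you want to preserve the state within an epoch but start from a zero state at the beginning of each epoch, you can do it like this: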
# Compute the zero state array of the right shape once
zero_state = sess.run(initial_state)
# Start with a zero vector and update it
cur_state = zero_state
for batch in get_batches():
    # Feed the previous final state back in as the initial state
    cur_state, _ = sess.run([state, ...], feed_dict={initial_state: cur_state, ...})
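The state-carrying loop above can be illustrated without TensorFlow. The following is a toy NumPy sketch, not the dynamic_rnn implementation: run_rnn, W_x, W_h and the sizes are all made up for the example. It shows that processing a sequence in chunks while passing each chunk's final state in as the next chunk's initial state gives the same final state as processing the whole sequence in one call.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, input_size, state_size = 2, 3, 4

# Fixed random weights for a toy tanh RNN (assumed values, for illustration only)
W_x = rng.normal(size=(input_size, state_size))
W_h = rng.normal(size=(state_size, state_size))

def run_rnn(inputs, initial_state):
    """Run the toy RNN over inputs of shape [batch_size, time, input_size]
    and return the final state of shape [batch_size, state_size]."""
    state = initial_state
    for t in range(inputs.shape[1]):
        state = np.tanh(inputs[:, t] @ W_x + state @ W_h)
    return state

sequence = rng.normal(size=(batch_size, 6, input_size))
zero_state = np.zeros((batch_size, state_size))

# Process the whole sequence in a single call
full = run_rnn(sequence, zero_state)

# Process it in two chunks, carrying the state from one chunk to the next
cur_state = zero_state
for chunk in (sequence[:, :3], sequence[:, 3:]):
    cur_state = run_rnn(chunk, cur_state)

print(np.allclose(full, cur_state))  # True
```

Resetting cur_state to zero_state at the start of each epoch corresponds to the "start from zero at the beginning of the epoch" behaviour described above.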