我正在尝试实现一个 seq2seq 模型,在 Tensorflow (1.6.0) 中使用 bidirectional_dynamic_decode、Attention 和 BeamSearchDecoder。(我试图只复制相关代码,以保持简单)
# encoder
def make_lstm(rnn_size, keep_prob):
lstm = tf.nn.rnn_cell.LSTMCell(rnn_size, initializer =
tf.truncated_normal_initializer(mean = 0.0, stddev = 1.0))
lstm_dropout = tf.nn.rnn_cell.DropoutWrapper(lstm, input_keep_prob
= keep_prob)
return lstm_dropout
cell_fw = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob)
for _ in range(n_layers)])
cell_bw = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob)
for _ in range(n_layers)])
enc_output, enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,
cell_bw,
rnn_inputs,
sequence_length=sequence_length,
dtype=tf.float32)
enc_output = tf.concat(enc_output,2)
dec_cell = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob)
for _ in range(num_layers)])
output_layer = Dense(vocab_size, kernel_initializer =
tf.truncated_normal_initializer(mean = 0.0,
stddev=0.1))
# training_decoding_layer
with tf.variable_scope('decode'):
....
# inference_decoding_layer
with tf.variable_scope('decode', reuse = True):
beam_width = 10
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(enc_output,
multiplier=beam_width)
tiled_encoder_final_state =
tf.contrib.seq2seq.tile_batch(enc_state, multiplier=beam_width)
tiled_sequence_length = tf.contrib.seq2seq.tile_batch(text_length,
multiplier=beam_width)
start_tokens = tf.tile(tf.constant([word2ind['<GO>']], dtype =
tf.int32), [batch_size], name = 'start_tokens')
attn_mech = tf.contrib.seq2seq.BahdanauAttention( num_units =
rnn_size,
memory =
tiled_encoder_outputs,
memory_sequence_length=tiled_sequence_length,
normalize=True )
beam_dec_cell = tf.contrib.seq2seq.AttentionWrapper(dec_cell,
attn_mech, rnn_size)
beam_initial_state = beam_dec_cell.zero_state(batch_size =
batch_size*beam_width , dtype = tf.float32)
beam_initial_state =
beam_initial_state.clone(cell_state=tiled_encoder_final_state)
但是,当我尝试将编码器的最后状态克隆到上图中的“beam_initial_state”变量时,出现以下错误:
ValueError: The two structures don't have the same number of elements.
First structure (6 elements): AttentionWrapperState(cell_state= .
(LSTMStateTuple(c=<tf.Tensor
'decode_1/AttentionWrapperZeroState/checked_cell_state:0' shape=(640,
256) dtype=float32>, h=<tf.Tensor
'decode_1/AttentionWrapperZeroState/checked_cell_state_1:0' shape=(640,
256) dtype=float32>),), attention=<tf.Tensor
'decode_1/AttentionWrapperZeroState/zeros_1:0' shape=(640, 256)
dtype=float32>, time=<tf.Tensor
'decode_1/AttentionWrapperZeroState/zeros:0' shape=() dtype=int32>,
alignments=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_2:0'
shape=(640, ?) dtype=float32>, alignment_history=(), attention_state=
<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_3:0' shape=(640,
?) dtype=float32>)
Second structure (8 elements): AttentionWrapperState(cell_state= .
((LSTMStateTuple(c=<tf.Tensor 'decode_1/tile_batch_1/Reshape:0' shape=
(?, 256) dtype=float32>, h=<tf.Tensor
'decode_1/tile_batch_1/Reshape_1:0' shape=(?, 256) dtype=float32>),),
(LSTMStateTuple(c=<tf.Tensor 'decode_1/tile_batch_1/Reshape_2:0' shape=
(?, 256) dtype=float32>, h=<tf.Tensor
'decode_1/tile_batch_1/Reshape_3:0' shape=(?, 256) dtype=float32>),)),
attention=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_1:0'
shape=(640, 256) dtype=float32>, time=<tf.Tensor
'decode_1/AttentionWrapperZeroState/zeros:0' shape=() dtype=int32>,
alignments=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_2:0'
shape=(640, ?) dtype=float32>, alignment_history=(), attention_state= .
<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_3:0' shape=(640,
?) dtype=float32>)
有人给点建议吗?提前非常感谢。