我正在尝试使用 Amazon 评论数据集进行文本摘要。我在构建模型时遇到了错误。
AttributeError: ‘LSTMStateTuple’ object has no attribute ‘get_shape’
我知道我错过了一些东西。但无法弄清楚它是什么。我是张量流的新手。
我想问题出在我的编码层上,我试图连接输出的方式。
def encoding_layer(embeded_rnn_input,rnn_size,keep_prob,num_layers,batch_size,source_sequence_length):
cell_fw = tf.contrib.rnn.MultiRNNCell([get_lstm(rnn_size,keep_prob) for _ in range(num_layers)])
cell_bw = tf.contrib.rnn.MultiRNNCell([get_lstm(rnn_size,keep_prob) for _ in range(num_layers)])
encoder_outputs,encoder_states = tf.nn.bidirectional_dynamic_rnn(cell_fw=cell_fw,cell_bw=cell_bw,inputs=embeded_rnn_input,
sequence_length=source_sequence_length,dtype=tf.float32)
encoder_outputs = tf.concat(encoder_outputs, 2)
return encoder_outputs,encoder_states
编辑:
删除了笔记本的链接,因为它将来会改变。为错误添加堆栈跟踪。
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-71-85ee67bc88e5> in <module>()
9 # Create the training and inference logits
10 training_logits, inference_logits = seq2seq_model(input_,target,embeding_matrix,vocab_to_int,source_seq_length,target_seq_length,
---> 11 max_target_seq_length,rnn_size,keep_probability,num_layers,batch_size)
12
13 # Create tensors for the training logits and inference logits
<ipython-input-70-5ad1bf459bd7> in seq2seq_model(source_input, target_input, embeding_matrix, vocab_to_int, source_sequence_length, target_sequence_length, max_target_length, rnn_size, keep_prob, num_layers, batch_size)
15 training_logits, inference_logits = decoding_layer(target_input,encoder_states,embedings,
16 vocab_to_int,rnn_size,target_sequence_length,
---> 17 max_target_length,batch_size,num_layers)
18
19 return training_logits, inference_logits
<ipython-input-69-c2b4542605d2> in decoding_layer(target_inputs, encoder_state, embedding, vocab_to_int, rnn_size, target_sequence_length, max_target_length, batch_size, num_layers)
12
13 training_logits = training_decoder(embed,decoder_cell,encoder_state,output_layer,
---> 14 target_sequence_length,max_target_length)
15
16
<ipython-input-67-91fcb3f89090> in training_decoder(dec_embed_input, decoder_cell, encoder_state, output_layer, target_sequence_length, max_target_length)
8
9 final_outputs, final_state,_ = tf.contrib.seq2seq.dynamic_decode(decoder=decoder,impute_finished=True,
---> 10 maximum_iterations=max_target_length)
11
12 return final_outputs
~/anaconda2/envs/tensorflow/lib/python3.5/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py in dynamic_decode(decoder, output_time_major, impute_finished, maximum_iterations, parallel_iterations, swap_memory, scope)
284 ],
285 parallel_iterations=parallel_iterations,
--> 286 swap_memory=swap_memory)
287
288 final_outputs_ta = res[1]
~/anaconda2/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name)
2773 context = WhileContext(parallel_iterations, back_prop, swap_memory, name)
2774 ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, context)
-> 2775 result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
2776 return result
2777
~/anaconda2/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py in BuildLoop(self, pred, body, loop_vars, shape_invariants)
2602 self.Enter()
2603 original_body_result, exit_vars = self._BuildLoop(
-> 2604 pred, body, original_loop_vars, loop_vars, shape_invariants)
2605 finally:
2606 self.Exit()
~/anaconda2/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/control_flow_ops.py in _BuildLoop(self, pred, body, original_loop_vars, loop_vars, shape_invariants)
2552 structure=original_loop_vars,
2553 flat_sequence=vars_for_body_with_tensor_arrays)
-> 2554 body_result = body(*packed_vars_for_body)
2555 if not nest.is_sequence(body_result):
2556 body_result = [body_result]
~/anaconda2/envs/tensorflow/lib/python3.5/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py in body(time, outputs_ta, state, inputs, finished, sequence_lengths)
232 """
233 (next_outputs, decoder_state, next_inputs,
--> 234 decoder_finished) = decoder.step(time, inputs, state)
235 next_finished = math_ops.logical_or(decoder_finished, finished)
236 if maximum_iterations is not None:
~/anaconda2/envs/tensorflow/lib/python3.5/site-packages/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py in step(self, time, inputs, state, name)
137 """
138 with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
--> 139 cell_outputs, cell_state = self._cell(inputs, state)
140 if self._output_layer is not None:
141 cell_outputs = self._output_layer(cell_outputs)
~/anaconda2/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell_impl.py in __call__(self, inputs, state, scope)
178 with vs.variable_scope(vs.get_variable_scope(),
179 custom_getter=self._rnn_get_variable):
--> 180 return super(RNNCell, self).__call__(inputs, state)
181
182 def _rnn_get_variable(self, getter, *args, **kwargs):
~/anaconda2/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/layers/base.py in __call__(self, inputs, *args, **kwargs)
448 # Check input assumptions set after layer building, e.g. input shape.
449 self._assert_input_compatibility(inputs)
--> 450 outputs = self.call(inputs, *args, **kwargs)
451
452 # Apply activity regularization.
~/anaconda2/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell_impl.py in call(self, inputs, state)
936 [-1, cell.state_size])
937 cur_state_pos += cell.state_size
--> 938 cur_inp, new_state = cell(cur_inp, cur_state)
939 new_states.append(new_state)
940
~/anaconda2/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell_impl.py in __call__(self, inputs, state, scope)
772 self._recurrent_input_noise,
773 self._input_keep_prob)
--> 774 output, new_state = self._cell(inputs, state, scope)
775 if _should_dropout(self._state_keep_prob):
776 new_state = self._dropout(new_state, "state",
~/anaconda2/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell_impl.py in __call__(self, inputs, state, scope)
178 with vs.variable_scope(vs.get_variable_scope(),
179 custom_getter=self._rnn_get_variable):
--> 180 return super(RNNCell, self).__call__(inputs, state)
181
182 def _rnn_get_variable(self, getter, *args, **kwargs):
~/anaconda2/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/layers/base.py in __call__(self, inputs, *args, **kwargs)
448 # Check input assumptions set after layer building, e.g. input shape.
449 self._assert_input_compatibility(inputs)
--> 450 outputs = self.call(inputs, *args, **kwargs)
451
452 # Apply activity regularization.
~/anaconda2/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell_impl.py in call(self, inputs, state)
399 c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
400
--> 401 concat = _linear([inputs, h], 4 * self._num_units, True)
402
403 # i = input_gate, j = new_input, f = forget_gate, o = output_gate
~/anaconda2/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell_impl.py in _linear(args, output_size, bias, bias_initializer, kernel_initializer)
1019 # Calculate the total size of arguments on dimension 1.
1020 total_arg_size = 0
-> 1021 shapes = [a.get_shape() for a in args]
1022 for shape in shapes:
1023 if shape.ndims != 2:
~/anaconda2/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell_impl.py in <listcomp>(.0)
1019 # Calculate the total size of arguments on dimension 1.
1020 total_arg_size = 0
-> 1021 shapes = [a.get_shape() for a in args]
1022 for shape in shapes:
1023 if shape.ndims != 2:
AttributeError: 'LSTMStateTuple' object has no attribute 'get_shape'