
I'm building an encoder-decoder model in TensorFlow 2 with the TensorFlow Addons package. For prediction I'm trying to use the tfa.seq2seq.BeamSearchDecoder class. Unfortunately, the output I get is a tensor whose time-step dimension is missing (it prints as None):

final_outputs:  (64, 5, None)
outputs:  (64, None)
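For reference, here is a minimal sketch I would use to sanity-check the decoding call itself, with made-up sizes and a plain LSTMCell instead of my attention-wrapped cell (all names and numbers below are placeholders, not my real model):

import tensorflow as tf
import tensorflow_addons as tfa

# All sizes here are placeholders, not my real hyperparameters.
batch_size, beam_width, vocab_size, units = 64, 5, 1000, 32

cell = tf.keras.layers.LSTMCell(units)
embedding_matrix = tf.Variable(tf.random.normal([vocab_size, units]))
decoder = tfa.seq2seq.BeamSearchDecoder(
    cell,
    beam_width=beam_width,
    output_layer=tf.keras.layers.Dense(vocab_size),
    maximum_iterations=10)  # cap decoding so the sketch always terminates

outputs, _, _ = decoder(
    embedding_matrix,
    start_tokens=tf.fill([batch_size], 1),
    end_token=2,
    initial_state=cell.get_initial_state(
        batch_size=batch_size * beam_width, dtype=tf.float32))

# predicted_ids is batch-major: [batch_size, time, beam_width]
print("predicted_ids: ", outputs.predicted_ids.shape)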

The error is most likely in the call method of my decoder class:

def call(self, dec_input, enc_output, enc_hidden, start_token=1, end_token=2, training=False):
    start_tokens = tf.fill([self.batch_sz], start_token)

    # Tile the encoder output for beam search:
    # [batch_size, max_length_input, rnn_units] -> [batch_size * beam_width, max_length_input, rnn_units]
    enc_out = tfa.seq2seq.tile_batch(enc_output, multiplier=self.beam_width)
    self.attention_mechanism.setup_memory(enc_out)

    # Set decoder_initial_state, which is an AttentionWrapperState, considering beam_width
    hidden_state = tfa.seq2seq.tile_batch(enc_hidden, multiplier=self.beam_width)
    decoder_initial_state = self.rnn_cell.get_initial_state(
        batch_size=self.beam_width * self.batch_sz, dtype=tf.float32)
    decoder_initial_state = decoder_initial_state.clone(cell_state=hidden_state)

    # Instantiate the BeamSearchDecoder
    decoder_instance = tfa.seq2seq.BeamSearchDecoder(
        self.rnn_cell, beam_width=self.beam_width, output_layer=self.fc)
    decoder_embedding_matrix = self.embd_layer.variables[0]
    print("decoder_embedding_matrix: ", decoder_embedding_matrix.shape)

    # The BeamSearchDecoder object's call() function takes care of everything.
    outputs, _, _ = decoder_instance(
        decoder_embedding_matrix, start_tokens=start_tokens,
        end_token=end_token, initial_state=decoder_initial_state)

    # predicted_ids: [batch, time, beam_width] -> [batch, beam_width, time]
    final_outputs = tf.transpose(outputs.predicted_ids, perm=(0, 2, 1))
    print("final_outputs: ", final_outputs.shape)  # prints (64, 5, None)

    # Keep only the best beam: [batch, length]
    outputs = final_outputs[:, 0, :]
    print("outputs: ", outputs.shape)  # prints (64, None)
    return outputs
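For completeness, this is roughly how the method is invoked at prediction time; `encoder`, `decoder`, and `example_batch` are placeholder names, not my actual objects:

# Sketch of the prediction step (placeholder names, not my real objects).
enc_output, enc_hidden = encoder(example_batch)        # encoder forward pass
predicted = decoder(example_batch, enc_output, enc_hidden)
print("outputs: ", predicted.shape)                    # prints (64, None)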

Please let me know if you need any further information or have an idea of how to fix this.

Thanks, and have a nice day,

Sören

