I am building an encoder-decoder model in TensorFlow 2 using the TensorFlow Addons package. For inference, I am trying to use the tfa.seq2seq.BeamSearchDecoder class. Unfortunately, the tensors I get back have no time dimension in their static shape:
final_outputs: (64, 5, None)
outputs: (64, None)
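For clarity, these are the static .shape values of final_outputs and outputs as defined at the end of the call() method below, printed roughly like this:

    print("final_outputs:", final_outputs.shape)  # (64, 5, None) -> [batch_sz, beam_width, time]
    print("outputs:", outputs.shape)              # (64, None)    -> [batch_sz, time]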
The problem is most likely in the call method of my decoder class:
def call(self, dec_input, enc_output, enc_hidden, start_token=1, end_token=2, training=False):
    # dec_input is unused here: during beam-search inference the decoder feeds
    # back its own predictions via the embedding matrix passed below.
    start_tokens = tf.fill([self.batch_sz], start_token)

    # Tile the encoder output across beams:
    # [batch_sz * beam_width, max_length_input, rnn_units]
    enc_out = tfa.seq2seq.tile_batch(enc_output, multiplier=self.beam_width)
    self.attention_mechanism.setup_memory(enc_out)

    # Build decoder_initial_state, an AttentionWrapperState that accounts for beam_width.
    hidden_state = tfa.seq2seq.tile_batch(enc_hidden, multiplier=self.beam_width)
    decoder_initial_state = self.rnn_cell.get_initial_state(
        batch_size=self.beam_width * self.batch_sz, dtype=tf.float32)
    decoder_initial_state = decoder_initial_state.clone(cell_state=hidden_state)

    # Instantiate the BeamSearchDecoder.
    decoder_instance = tfa.seq2seq.BeamSearchDecoder(
        self.rnn_cell, beam_width=self.beam_width, output_layer=self.fc)
    decoder_embedding_matrix = self.embd_layer.variables[0]
    print("decoder_embedding_matrix: ", decoder_embedding_matrix.shape)

    # The BeamSearchDecoder object's call() takes care of the whole decoding loop.
    outputs, _, _ = decoder_instance(
        decoder_embedding_matrix,
        start_tokens=start_tokens,
        end_token=end_token,
        initial_state=decoder_initial_state)

    # predicted_ids: [batch_sz, time, beam_width] -> [batch_sz, beam_width, time]
    final_outputs = tf.transpose(outputs.predicted_ids, perm=(0, 2, 1))
    outputs = final_outputs[:, 0, :]  # best beam: [batch_sz, time]
    return outputs
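For context, here is roughly how the attributes referenced in call() are set up. This is a simplified sketch following the tfa.seq2seq NMT-with-attention tutorial pattern; the constructor arguments (vocab_size, embedding_dim, rnn_units, ...) are placeholders for my actual values:

    import tensorflow as tf
    import tensorflow_addons as tfa

    class Decoder(tf.keras.Model):
        def __init__(self, vocab_size, embedding_dim, rnn_units, batch_sz, beam_width):
            super().__init__()
            self.batch_sz = batch_sz
            self.beam_width = beam_width
            self.embd_layer = tf.keras.layers.Embedding(vocab_size, embedding_dim)
            self.fc = tf.keras.layers.Dense(vocab_size)
            # Bahdanau attention; the memory is attached later via setup_memory() in call().
            self.attention_mechanism = tfa.seq2seq.BahdanauAttention(
                units=rnn_units, memory=None, memory_sequence_length=None)
            # AttentionWrapper around an LSTMCell, so get_initial_state() returns an
            # AttentionWrapperState that can be clone()d with the tiled encoder state.
            self.rnn_cell = tfa.seq2seq.AttentionWrapper(
                tf.keras.layers.LSTMCell(rnn_units),
                self.attention_mechanism,
                attention_layer_size=rnn_units)

With an LSTMCell, enc_hidden is the encoder's final [h, c] state pair; tfa.seq2seq.tile_batch handles that nested structure when tiling it across beams.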
Please let me know if you need any further information, or if you have ideas on how to solve this.

Thanks, and have a nice day,
Soeren