
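I think I'm using tensorflow_addons.seq2seq.BeamSearchDecoder incorrectly. I keep getting cryptic error messages about bad tensor shapes. My guess is that the problem lies in how initial_state and start_tokens are tiled, or possibly in the unspecified batch size. The documentation doesn't clearly specify how these arguments are supposed to work.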
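To make the tiling question concrete, here is a minimal NumPy sketch of what I understand tfa.seq2seq.tile_batch to do: repeat each batch entry `multiplier` times along axis 0, so a [batch_size, ...] tensor becomes [batch_size * multiplier, ...]. The helper name tile_batch_like and the example sizes are hypothetical, and this only illustrates the shape semantics; it is not the library's implementation.

```python
import numpy as np

def tile_batch_like(t, multiplier):
    """Repeat every batch entry `multiplier` times along axis 0 (hypothetical helper)."""
    return np.repeat(t, multiplier, axis=0)

# Example sizes chosen for illustration only.
batch_size, units, beam_width = 2, 512, 3

# An encoder state of shape [batch_size, units] becomes [batch_size * beam_width, units].
state_h = np.zeros((batch_size, units), dtype=np.float32)
tiled_h = tile_batch_like(state_h, beam_width)
print(tiled_h.shape)

# A single start token of shape [1] becomes shape [beam_width].
start_tokens = np.array([1], dtype=np.int32)
tiled_tokens = tile_batch_like(start_tokens, beam_width)
print(tiled_tokens.shape)

# Entries are repeated in place (a, a, b, b), not cycled (a, b, a, b).
print(tile_batch_like(np.array([10, 20]), 2))
```

If that reading is right, my tiled start_tokens always has shape [3], whatever the actual batch size is, which may be part of the mismatch.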

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras

vocab_size = 1000
embedding_size = 100

encoder_inputs = keras.layers.Input(shape=[None], dtype=tf.int32)
decoder_inputs = keras.layers.Input(shape=[None], dtype=tf.int32)
decoder_input_lengths = keras.layers.Input(shape=[], dtype=tf.int32)

embedding_layer = keras.layers.Embedding(vocab_size, embedding_size)
encoder_input_embeddings = embedding_layer(encoder_inputs)
decoder_input_embeddings = embedding_layer(decoder_inputs)

encoder = keras.layers.LSTM(512, return_state=True)
_, encoder_final_state_h, encoder_final_state_c = encoder(encoder_input_embeddings)
encoder_final_state = [encoder_final_state_h, encoder_final_state_c]

decoder_cell = keras.layers.LSTMCell(512)
output_layer = keras.layers.Dense(vocab_size)
beam_width = 3
decoder = tfa.seq2seq.BeamSearchDecoder(cell=decoder_cell, beam_width=beam_width, output_layer=output_layer)
decoder_initial_state = tfa.seq2seq.tile_batch(encoder_final_state, multiplier=beam_width)

start_token = 1
start_tokens = tfa.seq2seq.tile_batch(tf.constant([start_token]), multiplier=beam_width)
end_token = 2

outputs, _, _ = decoder(
    decoder_input_embeddings, start_tokens=start_tokens, end_token=end_token,
    initial_state=decoder_initial_state)

probas = tf.nn.softmax(outputs.rnn_output)

model = keras.Model(inputs=[encoder_inputs, decoder_inputs, decoder_input_lengths], outputs=probas)

Calling decoder raises this error:

ValueError: Dimension size must be evenly divisible by 4 but is 9
        Number of ways to split should evenly divide the split dimension for '{{node beam_search_decoder/decoder/while/BeamSearchDecoderStep/lstm_cell_1/split}} = Split[T=DT_FLOAT, num_split=4](beam_search_decoder/decoder/while/BeamSearchDecoderStep/lstm_cell_1/split/split_dim, beam_search_decoder/decoder/while/BeamSearchDecoderStep/lstm_cell_1/BiasAdd)' with input shapes: [], [9,9,2048] and with computed input tensors: input[0] = <1>.
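As far as I can tell from the message, the LSTM cell is trying to split its 2048 gate pre-activations (4 × 512 units) into four parts along axis 1, but the tensor it receives is rank 3 with shape [9, 9, 2048], so axis 1 has size 9 instead of 2048. The NumPy sketch below only reproduces that arithmetic; the shapes are taken from the error message, and the rank-2 "expected" case is my assumption about what the cell normally sees.

```python
import numpy as np

units = 512
num_gates = 4  # an LSTM cell computes four gate pre-activations per unit

# Assumed normal case: rank-2 pre-activations of shape [rows, 4 * units].
z_ok = np.zeros((9, num_gates * units), dtype=np.float32)
gates = np.split(z_ok, num_gates, axis=1)
print([g.shape for g in gates])

# Shape reported in the error: rank 3, so axis 1 has size 9, not 2048.
z_bad = np.zeros((9, 9, num_gates * units), dtype=np.float32)
try:
    np.split(z_bad, num_gates, axis=1)
    split_failed = False
except ValueError as err:
    # 9 is not divisible by 4, so an equal split is impossible.
    split_failed = True
    print("split failed:", err)
```

So the extra dimension presumably enters through the inputs fed to the cell, but I can't tell which argument introduces it.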

Can someone explain what is going wrong here and how to fix it?

TensorFlow 2.3.0
TensorFlow Addons 0.12.1
