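I think I am using tensorflow_addons.seq2seq.BeamSearchDecoder incorrectly. I get various cryptic messages about bad tensor shapes. My guess is that the problem is related to how I tile initial_state and start_tokens, or possibly to the unspecified batch size. The documentation does not clearly specify how these arguments are supposed to work.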
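For reference, my understanding of tile_batch, going by its documentation, is that it repeats each batch entry multiplier times along the leading dimension, keeping the copies adjacent, so a [batch_size, ...] tensor becomes [batch_size * multiplier, ...]. Below is a minimal plain-Python sketch of that behaviour. The helper name tile_batch_sketch is hypothetical and only illustrates the shapes on nested lists; it is not the library function.

```python
def tile_batch_sketch(batch, multiplier):
    """Repeat every batch entry `multiplier` times, keeping copies adjacent.

    Hypothetical stand-in for the documented behaviour of
    tfa.seq2seq.tile_batch, written for plain Python lists.
    """
    return [entry for entry in batch for _ in range(multiplier)]


beam_width = 3

# A single start token is a "batch" of size 1, as in tf.constant([start_token]).
tiled_start_tokens = tile_batch_sketch([1], beam_width)
print(len(tiled_start_tokens))  # leading dimension is now 1 * beam_width

# Assumed example: two encoder state vectors, i.e. batch_size = 2.
states = [[0.1, 0.2], [0.3, 0.4]]
tiled_states = tile_batch_sketch(states, beam_width)
print(len(tiled_states))  # leading dimension is now 2 * beam_width
```

If that reading is right, my tiled start_tokens has a leading dimension of beam_width rather than the batch size, which might matter if the decoder takes its batch size from start_tokens.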
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
vocab_size = 1000
embedding_size = 100
encoder_inputs = keras.layers.Input(shape=[None], dtype=tf.int32)
decoder_inputs = keras.layers.Input(shape=[None], dtype=tf.int32)
decoder_input_lengths = keras.layers.Input(shape=[], dtype=tf.int32)
embedding_layer = keras.layers.Embedding(vocab_size, embedding_size)
encoder_input_embeddings = embedding_layer(encoder_inputs)
decoder_input_embeddings = embedding_layer(decoder_inputs)
encoder = keras.layers.LSTM(512, return_state=True)
_, encoder_final_state_h, encoder_final_state_c = encoder(encoder_input_embeddings)
encoder_final_state = [encoder_final_state_h, encoder_final_state_c]
decoder_cell = keras.layers.LSTMCell(512)
output_layer = keras.layers.Dense(vocab_size)
beam_width = 3
decoder = tfa.seq2seq.BeamSearchDecoder(cell=decoder_cell, beam_width=beam_width, output_layer=output_layer)
decoder_initial_state = tfa.seq2seq.tile_batch(encoder_final_state, multiplier=beam_width)
start_token = 1
start_tokens = tfa.seq2seq.tile_batch(tf.constant([start_token]), multiplier=beam_width)
end_token = 2
outputs, _, _ = decoder(
    decoder_input_embeddings, start_tokens=start_tokens, end_token=end_token,
    initial_state=decoder_initial_state)
probas = tf.nn.softmax(outputs.rnn_output)
model = keras.Model(inputs=[encoder_inputs, decoder_inputs, decoder_input_lengths], outputs=probas)
The call to decoder raises this error:
ValueError: Dimension size must be evenly divisible by 4 but is 9
Number of ways to split should evenly divide the split dimension for '{{node beam_search_decoder/decoder/while/BeamSearchDecoderStep/lstm_cell_1/split}} = Split[T=DT_FLOAT, num_split=4](beam_search_decoder/decoder/while/BeamSearchDecoderStep/lstm_cell_1/split/split_dim, beam_search_decoder/decoder/while/BeamSearchDecoderStep/lstm_cell_1/BiasAdd)' with input shapes: [], [9,9,2048] and with computed input tensors: input[0] = <1>.
Can someone explain what is wrong here and how to fix it?
TensorFlow 2.3.0
TensorFlow Addons 0.12.1