I am trying to put together a simple piece of code for a seq2seq model in tf 1.1. I am not sure what the parameter "depth of query mechanism" is. I am getting an error when creating the attention mechanism that says:
TypeError: int() argument must be a string, a bytes-like object or a number, not 'TensorShape'
Here is my code. Am I on the right track? I can't find any detailed documentation.
import numpy as np
import tensorflow as tf
from tensorflow.contrib.rnn import (LSTMCell, LSTMStateTuple, BasicLSTMCell,
                                    DropoutWrapper, MultiRNNCell,
                                    EmbeddingWrapper, static_rnn)
import tensorflow.contrib.seq2seq as seq2seq
import attention_wrapper as wrapper  # local module providing AttentionWrapper
tf.reset_default_graph()
try:
    sess.close()
except:
    pass
sess = tf.InteractiveSession()
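# Toy hyperparameters so the snippet is self-contained (hypothetical values,
# chosen arbitrarily; these were not shown in the original snippet):
input_seq_length = 10     # number of encoder time steps
output_seq_length = 10    # number of decoder time steps
input_vocab_size = 100    # size of the input vocabulary
output_vocab_size = 100   # size of the output vocabulary (hypothetical, used below)
embedding_dim = 64        # embedding size / cell size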
## Placeholders
encode_input = [tf.placeholder(tf.int32,
                               shape=(None,),
                               name="ei_%i" % i)
                for i in range(input_seq_length)]

labels = [tf.placeholder(tf.int32,
                         shape=(None,),
                         name="l_%i" % i)
          for i in range(output_seq_length)]

decode_input = [tf.zeros_like(encode_input[0], dtype=np.int32, name="GO")] + labels[:-1]
############ Encoder
lstm_cell = BasicLSTMCell(embedding_dim)
encoder_cell = EmbeddingWrapper(lstm_cell,
                                embedding_classes=input_vocab_size,
                                embedding_size=embedding_dim)
# static_rnn returns a Python list of per-time-step outputs
# (each [batch, embedding_dim]), plus the final state
encoder_outputs, encoder_state = static_rnn(encoder_cell, encode_input,
                                            dtype=tf.float32)
############ Decoder
# Attention mechanism (Bahdanau is additive-style attention).
# The TypeError quoted above is raised by this call.
attn_mech = tf.contrib.seq2seq.BahdanauAttention(
    num_units=input_seq_length,   # "depth of query mechanism"
    memory=encoder_outputs,       # hidden states to attend to (output of the RNN)
    normalize=False,              # whether to normalize the energy term
    name='BahdanauAttention')
lstm_cell_decoder = BasicLSTMCell(embedding_dim)

# Attention wrapper: adds the attention mechanism to the cell
attn_cell = wrapper.AttentionWrapper(
    cell=lstm_cell_decoder,          # instance of RNNCell
    attention_mechanism=attn_mech,   # instance of AttentionMechanism
    attention_size=embedding_dim,    # int, depth of the attention (output) tensor
    attention_history=False,         # whether to store history in the final output
    name="attention_wrapper")
# Decoder setup
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=attn_cell,                # the attention-wrapped decoder cell
    helper=helper,                 # a Helper instance
    initial_state=encoder_state,   # initial state of the decoder
    output_layer=None)             # instance of tf.layers.Layer, e.g. Dense
# Perform dynamic decoding with decoder object
outputs, final_state = tf.contrib.seq2seq.dynamic_decode(decoder)
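# For reference: outputs is a BasicDecoderOutput named tuple; its
# rnn_output field holds the per-step decoder outputs.
logits = outputs.rnn_output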