I tried to implement an encoder-decoder model in TensorFlow following the tutorial https://www.tensorflow.org/tutorials/seq2seq. I made a simple encoder-decoder model:
from __future__ import print_function
import numpy as np
import tensorflow as tf
from utils2 import *


def def_length(sequence):
    used = tf.sign(tf.reduce_max(tf.abs(sequence), 2))
    dlength = tf.reduce_sum(used, 0)
    dlength = tf.cast(dlength, tf.int32)
    return dlength


tf.reset_default_graph()

char2idx, idx2char = load_vocab()
x, y, z = get_batch()
num_mels = y.shape[-1]
print('x shape', x.shape)
print('y shape', y.shape)

inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='inputs')
decoder_inputs = tf.placeholder(shape=(None, None, num_mels), dtype=tf.float32, name='decoder_inputs')
decoder_lengths = tf.placeholder(shape=None, dtype=tf.int32, name='decoder_lengths')
decoder_targets = tf.placeholder(shape=(None, None, num_mels), dtype=tf.float32, name='decoder_targets')

lookup_table = tf.get_variable('lookup_table',
                               shape=[len(p.vocab), p.embed_size],
                               dtype=tf.float32)
encoder_inputs = tf.nn.embedding_lookup(lookup_table, inputs)

# encoder
encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(p.gru_units)
encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_inputs,
                                                   dtype=tf.float32,
                                                   sequence_length=def_length(encoder_inputs),
                                                   time_major=True, scope='encoder')

# decoder
decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(p.gru_units)
decoder_outputs, decoder_state = tf.nn.dynamic_rnn(decoder_cell, decoder_inputs,
                                                   initial_state=encoder_state,
                                                   sequence_length=def_length(decoder_inputs),
                                                   dtype=tf.float32, time_major=True, scope='decoder')

outputs = tf.layers.dense(decoder_outputs, num_mels)

loss = tf.reduce_mean(tf.abs(y - outputs))
train_op = tf.train.AdamOptimizer().minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    outs = sess.run(train_op,
                    feed_dict={
                        inputs: x,
                        decoder_inputs: y,
                    })
    print('outputs shape', outs)
It works. Now I want to add an attention mechanism to the decoder:
# decoder with attention
attention_states = tf.transpose(encoder_outputs, [1, 0, 2])

attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    p.gru_units, attention_states,
    memory_sequence_length=def_length(encoder_inputs)
)
But when I run

outs = sess.run(attention_mechanism, feed_dict={
    inputs: x,
    decoder_inputs: y,
})

I get this error:
Traceback (most recent call last):
File "C:\Users\Admin\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 282, in __init__
fetch, allow_tensor=True, allow_operation=True))
File "C:\Users\Admin\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 3590, in as_graph_element
return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
File "C:\Users\Admin\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 3679, in _as_graph_element_locked
types_str))
TypeError: Can not convert a LuongAttention into a Tensor or Operation.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "D:/Programming/Projects/tasks/text_to_speech/tts/tf_seq2seq.py", line 60, in <module>
decoder_inputs: y,
File "C:\Users\Admin\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 900, in run
run_metadata_ptr)
File "C:\Users\Admin\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1120, in _run
self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
File "C:\Users\Admin\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 427, in __init__
self._fetch_mapper = _FetchMapper.for_fetch(fetches)
File "C:\Users\Admin\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 253, in for_fetch
return _ElementFetchMapper(fetches, contraction_fn)
File "C:\Users\Admin\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 286, in __init__
(fetch, type(fetch), str(e)))
TypeError: Fetch argument <tensorflow.contrib.seq2seq.python.ops.attention_wrapper.LuongAttention object at 0x00000000127C1898> has invalid type <class 'tensorflow.contrib.seq2seq.python.ops.attention_wrapper.LuongAttention'>, must be a string or Tensor. (Can not convert a LuongAttention into a Tensor or Operation.)
So the error says a LuongAttention can not be converted into a Tensor or Operation. I checked attention_states: it is a tensor of shape (10, 155, 512), and p.gru_units is an integer equal to 512. I don't understand what cannot be converted into a tensor. Thanks in advance.