我最近一直在研究 tensorflow。我正在编码 seq2seq 模型。我正在编写一个条件来选择 API 提供的帮助程序。
当我使用以下代码时,我面临错误。
Training Helper
helper1 = tf.contrib.seq2seq.TrainingHelper(inputs = decoder_embedded_input,sequence_length = dec_seqLen,time_major=True)
helper2 = tf.contrib.seq2seq.GreedyEmbeddingHelper(output_embedding,
tf.fill([batchSize], outT2N['<GO>']),
outT2N['<EOS>'])
helperDecider = tf.placeholder(tf.bool)
# when 0 : helper2
# when 1 : helper1
helper = tf.cond(helperDecider,helper1,helper2)
我得到错误助手必须是可调用的,所以我将代码更改为
def helper1():
return tf.contrib.seq2seq.TrainingHelper(inputs = dec_embedded_input,sequence_length = dec_seqLen,time_major=True)
def helper2():
return tf.contrib.seq2seq.GreedyEmbeddingHelper(embeddingMatrixOut,
tf.fill([batchSize], outT2N['<GO>']),
outT2N['<EOS>'])
helperDecider = tf.placeholder(tf.bool)
# when 0 : helper2
# when 1 : helper1
helper = tf.cond(helperDecider,helper1,helper2)
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell,helper,encoder_final_state ,output_layer = projection_layer)
现在它正在抛出错误,
Expected binary or unicode string, got <tensorflow.contrib.seq2seq.python.ops.helper.TrainingHelper object at 0x7fc32b96b908>
所以,最后我选择了旧的 if-else 并且它应该像它一样工作。我只需要使用以下代码是否有效。
#Training Helper
helper1 = tf.contrib.seq2seq.TrainingHelper(inputs = dec_embedded_input,sequence_length = dec_seqLen,time_major=True)
helper2 = tf.contrib.seq2seq.GreedyEmbeddingHelper(embeddingMatrixOut,
tf.fill([batchSize], outT2N['<GO>']),
outT2N['<EOS>'])
helperDecider = tf.placeholder(tf.bool)
# when 0 : helper2
# when 1 : helper1
if someCondition:
helper = helper1
else:
helper = helper2
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell,helper,encoder_final_state ,output_layer = projection_layer)
可能的错误是由于硬编码,我无法在运行时更改为其他助手。有人可以建议一种替代方法吗?