0

我最近一直在研究 tensorflow。我正在编码 seq2seq 模型。我正在编写一个条件来选择 API 提供的帮助程序。

当我使用以下代码时,我面临错误。

Training Helper
helper1 = tf.contrib.seq2seq.TrainingHelper(inputs = decoder_embedded_input,sequence_length = dec_seqLen,time_major=True)


helper2 = tf.contrib.seq2seq.GreedyEmbeddingHelper(output_embedding,
                                                   tf.fill([batchSize], outT2N['<GO>']),
                                                   outT2N['<EOS>'])
helperDecider = tf.placeholder(tf.bool)
# when 0 : helper2
# when 1 : helper1
helper = tf.cond(helperDecider,helper1,helper2)

我得到错误助手必须是可调用的,所以我将代码更改为

def helper1():
    return tf.contrib.seq2seq.TrainingHelper(inputs = dec_embedded_input,sequence_length = dec_seqLen,time_major=True)

def helper2():
    return tf.contrib.seq2seq.GreedyEmbeddingHelper(embeddingMatrixOut,
                                                   tf.fill([batchSize], outT2N['<GO>']),
                                                   outT2N['<EOS>'])
helperDecider = tf.placeholder(tf.bool)
# when 0 : helper2
# when 1 : helper1

helper = tf.cond(helperDecider,helper1,helper2)

decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell,helper,encoder_final_state ,output_layer = projection_layer)

现在它正在抛出错误,

Expected binary or unicode string, got <tensorflow.contrib.seq2seq.python.ops.helper.TrainingHelper object at 0x7fc32b96b908>

所以,最后我选择了旧的 if-else 并且它应该像它一样工作。我只需要使用以下代码是否有效。

#Training Helper
helper1 = tf.contrib.seq2seq.TrainingHelper(inputs = dec_embedded_input,sequence_length = dec_seqLen,time_major=True)


helper2 = tf.contrib.seq2seq.GreedyEmbeddingHelper(embeddingMatrixOut,
                                                   tf.fill([batchSize], outT2N['<GO>']),
                                                   outT2N['<EOS>'])
helperDecider = tf.placeholder(tf.bool)
# when 0 : helper2
# when 1 : helper1

if someCondition:
    helper = helper1
else:
    helper = helper2


decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell,helper,encoder_final_state ,output_layer = projection_layer)

可能的错误是由于硬编码,我无法在运行时更改为其他助手。有人可以建议一种替代方法吗?

4

1 回答 1

0

如果条件在 tensorflow 中无效,则为正常,因为一旦评估图表就定义了图表(它将始终评估为 true,并且 helper 将设置为 helper1)。

为此,可以使用tf.cond来定义条件图,就像您在第一个片段中所做的那样。你得到的错误与错误使用 tf.cond 函数有关,它应该是这样的

helper = tf.cond(helperDecider,lambda: helper1, lambda: helper2)

函数的第二个和第三个参数应该是一个它不是的函数。希望这有效..

于 2018-04-17T18:54:27.613 回答