0

在使用 seq2seq 架构的 NMT 中,在推理过程中,我们需要在训练阶段训练的嵌入变量作为 GreedyEmbeddingHelper 或 BeamSearchDecoder 的输入。

问题是,在使用 Estimator API 进行训练和推断的上下文中,我们如何提取这个经过训练的嵌入变量以用于预测?

4

1 回答 1

0

我想出了一个基于以下 stackoverflow答案的解决方案。对于预测阶段,您可以使用 tf.contrib.framework.load_variable 从经过训练和保存的 Tensorflow 模型中检索嵌入变量,如下所示:

if mode == tf.estimator.ModeKeys.PREDICT:
    embeddings = tf.constant(tf.contrib.framework.load_variable('.','embed/embeddings'))
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding=embeddings,
    start_tokens=tf.fill([batch_size], 1),end_token=0)

因此,就我而言,我从包含已保存模型的同一文件夹中运行代码,并且我的变量名是“嵌入/嵌入”。请注意,这仅适用于通过 tensorflow 模型训练的嵌入。否则,请参阅上面链接的答案。

要使用估算器 API 查找变量名称,您可以使用 get_variable_names() 方法获取保存在图中的所有变量名称的列表。

于 2018-03-26T09:36:41.210 回答