您可以在答案末尾找到所有代码。
我认为您的大多数问题(为什么使用 Softmax、如何使用预训练嵌入层等)都已得到解答。但是,由于您仍在等待简洁的代码从种子生成生成的文本,因此我在这里尝试报告我自己最终是如何做到的。
我从官方的 Tensorflow 教程开始苦苦挣扎,直到我可以轻松地从生成的模型中生成单词。幸运的是,在对您在问题中提到的几乎所有答案进行了一些回答之后,我对问题(和解决方案)有了更好的了解。这可能包含错误,但至少它会运行并生成一些文本......
给定句子的前几个单词,如何使用生成的模型实际生成下一个单词建议?
我会将下一个单词建议包装在一个循环中,以生成一个完整的句子,但您很容易将其简化为一个单词。
假设您遵循了 tensorflow(撰写本文时为 v1.4)给出的当前教程,这将在训练后保存模型。
然后我们要做的就是从磁盘加载它,并编写一个函数,该函数接受这个模型和一些种子输入并返回生成的文本。
从保存的模型生成文本
我假设我们在一个新的 python 脚本中编写了所有这些代码。底部的整个脚本作为回顾,在这里我解释了主要步骤。
第一个必要步骤
FLAGS = tf.flags.FLAGS
FLAGS.model = "medium" # or whatever size you used
现在,非常重要的是,我们创建字典来将 id 映射到单词,反之亦然(因此我们不必读取整数列表......)。
word_to_id = reader._build_vocab('../data/ptb.train.txt') # here we load the word -> id dictionnary ()
id_to_word = dict(zip(word_to_id.values(), word_to_id.keys())) # and transform it into id -> word dictionnary
_, _, test_data, _ = reader.ptb_raw_data('../data')
然后我们加载配置类,也将num_steps
and设置batch_size
为 1,因为我们希望一次采样 1 个单词,而 LSTM 一次也将处理 1 个单词。还即时创建输入实例:
eval_config = get_config()
eval_config.num_steps = 1
eval_config.batch_size = 1
model_input = PTBInput(eval_config, test_data)
建筑图
要加载保存的模型(由Supervisor.saver
教程中的模块保存),我们首先需要重建图形(使用类很容易PTBModel
),它必须使用与训练时相同的配置:
sess = tf.Session()
initializer = tf.random_uniform_initializer(-eval_config.init_scale, eval_config.init_scale)
# not sure but seems to need the same name for variable scope as when saved ....!!
with tf.variable_scope("Model", reuse=None, initializer=initializer):
tf.global_variables_initializer()
mtest = PTBModel(is_training=False, config=eval_config, input=model_input)
恢复保存的重量:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint('../Whatever_folder_you_saved_in')) # the path must point to the hierarchy where your 'checkpoint' file is
...从给定的种子中采样单词:
首先,我们需要模型包含对 logits 输出的访问,或者更准确地说是整个词汇表的概率分布。所以在ptb_lstm.py
文件中添加以下行:
# the line goes somewhere below the reshaping "logits = tf.reshape(logits, [self.batch_size, ..."
self.probas = tf.nn.softmax(logits, name="probas")
然后我们可以设计一些采样函数(你可以在这里随意使用任何你喜欢的东西,最好的方法是使用趋于平坦或锐化分布的温度进行采样),这是一个基本的随机采样方法:
def sample_from_pmf(probas):
t = np.cumsum(probas)
s = np.sum(probas)
return int(np.searchsorted(t, np.random.rand(1) * s))
最后是一个函数,它接受一个种子、你的模型、将单词映射到 id 的字典,反之亦然,作为输入和输出生成的文本字符串:
def generate_text(session, model, word_to_index, index_to_word,
seed='</s>', n_sentences=10):
sentence_cnt = 0
input_seeds_id = [word_to_index[w] for w in seed.split()]
state = session.run(model.initial_state)
# Initiate network with seeds up to the before last word:
for x in input_seeds_id[:-1]:
feed_dict = {model.initial_state: state,
model.input.input_data: [[x]]}
state = session.run([model.final_state], feed_dict)
text = seed
# Generate a new sample from previous, starting at last word in seed
input_id = [[input_seeds_id[-1]]]
while sentence_cnt < n_sentences:
feed_dict = {model.input.input_data: input_id,
model.initial_state: state}
probas, state = session.run([model.probas, model.final_state],
feed_dict=feed_dict)
sampled_word = sample_from_pmf(probas[0])
if sampled_word == word_to_index['</s>']:
text += '.\n'
sentence_cnt += 1
else:
text += ' ' + index_to_word[sampled_word]
input_wordid = [[sampled_word]]
return text
TL;博士
不要忘记添加以下行:
self.probas = tf.nn.softmax(logits, name='probas')
在ptb_lstm.py
文件中,在类的__init__
定义中PTBModel
,行之后的任何位置logits = tf.reshape(logits, [self.batch_size, self.num_steps, vocab_size])
。
整个脚本,只需从您拥有的同一目录中运行它reader.py
,ptb_lstm.py
:
import reader
import numpy as np
import tensorflow as tf
from ptb_lstm import PTBModel, get_config, PTBInput
FLAGS = tf.flags.FLAGS
FLAGS.model = "medium"
def sample_from_pmf(probas):
t = np.cumsum(probas)
s = np.sum(probas)
return int(np.searchsorted(t, np.random.rand(1) * s))
def generate_text(session, model, word_to_index, index_to_word,
seed='</s>', n_sentences=10):
sentence_cnt = 0
input_seeds_id = [word_to_index[w] for w in seed.split()]
state = session.run(model.initial_state)
# Initiate network with seeds up to the before last word:
for x in input_seeds_id[:-1]:
feed_dict = {model.initial_state: state,
model.input.input_data: [[x]]}
state = session.run([model.final_state], feed_dict)
text = seed
# Generate a new sample from previous, starting at last word in seed
input_id = [[input_seeds_id[-1]]]
while sentence_cnt < n_sentences:
feed_dict = {model.input.input_data: input_id,
model.initial_state: state}
probas, state = sess.run([model.probas, model.final_state],
feed_dict=feed_dict)
sampled_word = sample_from_pmf(probas[0])
if sampled_word == word_to_index['</s>']:
text += '.\n'
sentence_cnt += 1
else:
text += ' ' + index_to_word[sampled_word]
input_wordid = [[sampled_word]]
print(text)
if __name__ == '__main__':
word_to_id = reader._build_vocab('../data/ptb.train.txt') # here we load the word -> id dictionnary ()
id_to_word = dict(zip(word_to_id.values(), word_to_id.keys())) # and transform it into id -> word dictionnary
_, _, test_data, _ = reader.ptb_raw_data('../data')
eval_config = get_config()
eval_config.batch_size = 1
eval_config.num_steps = 1
model_input = PTBInput(eval_config, test_data, name=None)
sess = tf.Session()
initializer = tf.random_uniform_initializer(-eval_config.init_scale,
eval_config.init_scale)
with tf.variable_scope("Model", reuse=None, initializer=initializer):
tf.global_variables_initializer()
mtest = PTBModel(is_training=False, config=eval_config,
input_=model_input)
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint('../models'))
while True:
print(generate_text(sess, mtest, word_to_id, id_to_word, seed="this sentence is"))
try:
raw_input('press Enter to continue ...\n')
except KeyboardInterrupt:
print('\b\bQuiting now...')
break
更新
至于使用最近的 tensorflow(至少 1.6)恢复旧的检查点(对我来说是 6 个月前保存的模型,不确定当时使用的确切 TF 版本),它可能会引发一些关于未找到变量的错误(见评论)。在这种情况下,您应该使用此脚本更新您的检查点。
另外,请注意,对我来说,我必须进一步修改它,因为我注意到该saver.restore
函数正在尝试读取lstm_cell
变量,尽管我的变量被转换为basic_lstm_cell
也导致NotFound Error
. 因此,一个简单的解决方法是删除新名称中的checkpoint_convert.py
脚本,即第 72-73 行中的一个小改动。basic_
检查检查点中包含的变量名称的一种方便方法是(CKPT_FILE
是 , 之前的后缀.index
等.data0000-1000
):
reader = tf.train.NewCheckpointReader(CKPT_FILE)
reader.get_variable_to_shape_map()
通过这种方式,您可以验证您确实拥有正确的名称(或旧检查点版本中的错误名称)。