4

TensorFlow新手在这里。我正在尝试构建一个 RNN。我的输入数据是一组大小的向量实例,instance_size表示每个时间步中一组粒子的 (x,y) 位置。(由于实例已经具有语义内容,它们不需要嵌入。)目标是在下一步学习预测粒子的位置。

按照RNN 教程并稍微修改包含的 RNN 代码,我创建了一个或多或少像这样的模型(省略了一些细节):

inputs, self._input_data = tf.placeholder(tf.float32, [batch_size, num_steps, instance_size])
self._targets = tf.placeholder(tf.float32, [batch_size, num_steps, instance_size])

with tf.variable_scope("lstm_cell", reuse=True):
  lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size, forget_bias=0.0)
  if is_training and config.keep_prob < 1:
    lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
        lstm_cell, output_keep_prob=config.keep_prob)
  cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * config.num_layers)

self._initial_state = cell.zero_state(batch_size, tf.float32)

from tensorflow.models.rnn import rnn
inputs = [tf.squeeze(input_, [1])
          for input_ in tf.split(1, num_steps, inputs)]
outputs, state = rnn.rnn(cell, inputs, initial_state=self._initial_state)

output = tf.reshape(tf.concat(1, outputs), [-1, hidden_size])
softmax_w = tf.get_variable("softmax_w", [hidden_size, instance_size])
softmax_b = tf.get_variable("softmax_b", [instance_size])
logits = tf.matmul(output, softmax_w) + softmax_b
loss = position_squared_error_loss(
    tf.reshape(logits, [-1]),
    tf.reshape(self._targets, [-1]),
)
self._cost = cost = tf.reduce_sum(loss) / batch_size
self._final_state = state

然后我创建一个saver = tf.train.Saver(),迭代数据以使用给定的run_epoch()方法训练它,并用saver.save(). 到目前为止,一切都很好。

但是我如何实际使用训练好的模型呢?教程到此停止。从docs on 开始tf.train.Saver.restore(),为了读回变量,我需要设置与保存变量时运行的完全相同的图表,或者有选择地恢复特定变量。无论哪种方式,这意味着我的新模型将需要 size 的输入batch_size x num_steps x instance_size。然而,我现在想要的只是在一个大小的输入上对模型进行一次前向传递,num_steps x instance_size并读出一个instance_size大小的结果(下一个时间步的预测);换句话说,我想创建一个模型,它接受与我训练的张量不同的张量。我可以通过将现有模型传递给我的预期数据来解决它batch_size次,但这似乎不是最佳实践。最好的方法是什么?

4

1 回答 1

2

您必须创建一个具有相同结构但使用 的新图形,batch_size = 1并使用 导入保存的变量tf.train.Saver.restore()。您可以查看他们如何在 ptb_word_lm.py 中定义具有可变批量大小的多个模型:https ://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/rnn/ptb/ptb_word_lm.py

因此,您可以拥有一个单独的文件,例如,在其中使用所需的 batch_size 实例化图形,然后恢复保存的变量。然后你可以执行你的图表。

于 2016-04-08T23:40:20.673 回答