
When using TensorFlow's Dataset API Iterator, my goal is to define an RNN that takes the iterator's get_next() tensors as its input (see (1) in the code).

However, simply passing get_next() as the input to dynamic_rnn results in an error: ValueError: Initializer for variable rnn/basic_lstm_cell/kernel/ is from inside a control-flow construct, such as a loop or conditional. When creating a variable inside a loop or conditional, use a lambda as the initializer.

I know that one workaround is to create a placeholder for next_batch, eval() the tensor (because you can't feed the tensor itself), and pass the result in through feed_dict (see X and (2) in the code). However, if I understand correctly, this is inefficient, because each batch is first evaluated into a NumPy array in Python and then fed back into the graph.

Is there any way to either:

  1. Define the dynamic_rnn directly on top of the output of the Iterator;

or:

  2. Somehow directly pass the existing get_next() tensor to the placeholder that is the input of dynamic_rnn?

Full working example: version (1) is what I would like to work, but it doesn't; version (2) is the workaround that does work.

import tensorflow as tf

from tensorflow.contrib.rnn import BasicLSTMCell
from tensorflow.python.data import Iterator

data = [ [[1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]] ]
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(2)
iterator = Iterator.from_structure(dataset.output_types,
                                   dataset.output_shapes)
next_batch = iterator.get_next()
iterator_init = iterator.make_initializer(dataset)

# (2):
X = tf.placeholder(tf.float32, shape=(None, 3, 1))

cell = BasicLSTMCell(num_units=8)

# (1):
# outputs, states = lstm_outputs, lstm_states = tf.nn.dynamic_rnn(cell, next_batch, dtype=tf.float32)

# (2):
outputs, states = lstm_outputs, lstm_states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    sess.run(iterator_init)

    # (1):
    # o, s = sess.run([outputs, states])
    # o, s = sess.run([outputs, states])

    # (2):
    o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})
    o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})

(Using TensorFlow 1.4.0 and Python 3.6.)

Thank you very much :)


1 Answer


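It turns out this mysterious error is most likely a bug in TensorFlow; see https://github.com/tensorflow/tensorflow/issues/14729. More specifically, the error actually comes from supplying the wrong data type: in my example above, the data array contains int32 values, but it should contain floats.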

When running version (1), instead of the error ValueError: Initializer for variable rnn/basic_lstm_cell/kernel/ is from inside a control-flow construct, TensorFlow should have raised:

TypeError: Tensors in list passed to 'values' of 'ConcatV2' Op have types [int32, float32] that don't all match.

To fix this, simply change

data = [ [[1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]] ]

to

data = np.array([[ [1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]] ], dtype=np.float32)

The following code should then work correctly:

import tensorflow as tf
import numpy as np

from tensorflow.contrib.rnn import BasicLSTMCell
from tensorflow.python.data import Iterator

data = np.array([[ [1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]] ], dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(2)
iterator = Iterator.from_structure(dataset.output_types,
                                   dataset.output_shapes)
next_batch = iterator.get_next()
iterator_init = iterator.make_initializer(dataset)

# (2):
# X = tf.placeholder(tf.float32, shape=(None, 3, 1))

cell = BasicLSTMCell(num_units=8)

# (1):
outputs, states = lstm_outputs, lstm_states = tf.nn.dynamic_rnn(cell, next_batch, dtype=tf.float32)

# (2):
# outputs, states = lstm_outputs, lstm_states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    sess.run(iterator_init)

    # (1):
    o, s = sess.run([outputs, states])
    o, s = sess.run([outputs, states])

    # (2):
    # o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})
    # o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})
answered 2017-11-21T09:25:03.897