26

我正在尝试将列表传递给feed_dict,但是我在这样做时遇到了麻烦。说我有:

inputs = 10 * [tf.placeholder(tf.float32, shape=(batch_size, input_size))]

其中输入被输入到outputs我想要计算的某个函数中。因此,为了在 tensorflow 中运行它,我创建了一个会话并运行以下命令:

sess.run(outputs, feed_dict = {inputs: data}) 
#data is my list of inputs, which is also of length 10

但我得到一个错误,TypeError: unhashable type: 'list'. 但是,我能够像这样传递数据元素:

sess.run(outputs, feed_dict = {inputs[0]: data[0], ..., inputs[9]: data[9]}) 

所以我想知道是否有办法解决这个问题。我也尝试构建一个字典(使用for循环),但是这会导致字典只有一个元素,它们的关键是: tensorflow.python.framework.ops.Tensor at 0x107594a10

4

3 回答 3

47

这里有两个问题导致问题:

第一个问题是Session.run()调用只接受少数类型作为feed_dict. 特别是,支持张量列表作为键,因此您必须将每个张量作为单独的键。*一种方便的方法是使用字典推导:

inputs = [tf.placeholder(...), ...]
data = [np.array(...), ...]
sess.run(y, feed_dict={i: d for i, d in zip(inputs, data)})

第二个问题是10 * [tf.placeholder(...)]Python 中的语法创建了一个包含十个元素的列表,其中每个元素都是相同的张量对象(即具有相同的name属性,相同的id属性,并且如果您使用 比较列表中的两个元素,则引用相同inputs[i] is inputs[j]) . 这解释了为什么当您尝试使用列表元素作为键创建字典时,最终得到的字典只有一个元素——因为所有列表元素都是相同的。

如您所愿,要创建 10 个不同的占位符张量,您应该改为执行以下操作:

inputs = [tf.placeholder(tf.float32, shape=(batch_size, input_size))
          for _ in xrange(10)]

如果您打印此列表的元素,您会看到每个元素都是具有不同名称的张量。


编辑: *您现在可以将元组作为 a 的键传递feed_dict,因为这些可以用作字典键。

于 2015-11-13T03:04:05.103 回答
6

这是一个正确的例子:

batch_size, input_size, n = 2, 3, 2
# in your case n = 10
x = tf.placeholder(tf.types.float32, shape=(n, batch_size, input_size))
y = tf.add(x, x)

data = np.random.rand(n, batch_size, input_size)

sess = tf.Session()
print sess.run(y, feed_dict={x: data})

这是我在你的方法中看到的一个奇怪的事情。出于某种原因10 * [tf.placeholder(...)],您使用了 10 个大小的张量(batch_size, input_size)。不知道你为什么要这样做,如果你可以在 3 阶的张量上创建(第一个维度是 10)。

因为你有一个张量列表(而不是张量),所以你不能将你的数据提供给这个列表(但在我的情况下,我可以提供给我的张量)。

于 2015-11-13T02:11:21.420 回答
2

feed_dict 可以通过预先准备字典来提供,如下所示

n = 10
input_1 = [tf.placeholder(...) for _ in range(n)]
input_2 = tf.placeholder(...)
data_1 = [np.array(...) for _ in range(n)]
data_2 = np.array(...)


feed_dictionary = {}
for i in range(n):
    feed_dictionary[input_1[i]] = data_1[i]
feed_dictionary[input_2] = data_2
sess.run(y, feed_dict=feed_dictionary)
于 2019-01-09T06:42:14.537 回答