43

我正在尝试一个行为不符合预期的操作。

graph = tf.Graph()
with graph.as_default():
  train_dataset = tf.placeholder(tf.int32, shape=[128, 2])
  embeddings = tf.Variable(
    tf.random_uniform([50000, 64], -1.0, 1.0))
  embed = tf.nn.embedding_lookup(embeddings, train_dataset)
  embed = tf.reduce_sum(embed, reduction_indices=0)

所以我需要知道 Tensor 的尺寸embed。我知道它可以在运行时完成,但是对于这样一个简单的操作来说工作量太大了。更简单的方法是什么?

4

6 回答 6

57

我看到大多数人对此感到困惑tf.shape(tensor)tensor.get_shape() 让我们澄清一下:

  1. tf.shape

tf.shape用于动态形状。如果您的张量的形状是可变的,请使用它。一个例子:一个输入是一个宽度和高度可变的图像,我们想把它调整到它的一半大小,那么我们可以这样写:
new_height = tf.shape(image)[0] / 2

  1. tensor.get_shape

tensor.get_shape用于固定形状,这意味着可以在图中推导出张量的形状。

结论: tf.shape几乎可以在任何地方使用,但t.get_shape只能从图形中推断出形状。

于 2017-04-08T06:23:07.180 回答
46

Tensor.get_shape()这篇文章

从文档

c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(c.get_shape())
==> TensorShape([Dimension(2), Dimension(3)])
于 2016-05-01T13:06:54.123 回答
13

访问值的函数:

def shape(tensor):
    s = tensor.get_shape()
    return tuple([s[i].value for i in range(0, len(s))])

例子:

batch_size, num_feats = shape(logits)
于 2017-01-20T19:48:36.550 回答
5

只需在构建图(操作)后打印出嵌入而不运行:

import tensorflow as tf

...

train_dataset = tf.placeholder(tf.int32, shape=[128, 2])
embeddings = tf.Variable(
    tf.random_uniform([50000, 64], -1.0, 1.0))
embed = tf.nn.embedding_lookup(embeddings, train_dataset)
print (embed)

这将显示嵌入张量的形状:

Tensor("embedding_lookup:0", shape=(128, 2, 64), dtype=float32)

通常,最好在训练模型之前检查所有张量的形状。

于 2016-05-01T13:08:34.273 回答
4

让我们把它简单得像地狱。如果你想要一个单一的数字来表示维度的数量,2, 3, 4, etc.,那么只需使用tf.rank(). 但是,如果您想要张量的确切形状,请使用tensor.get_shape()

with tf.Session() as sess:
   arr = tf.random_normal(shape=(10, 32, 32, 128))
   a = tf.random_gamma(shape=(3, 3, 1), alpha=0.1)
   print(sess.run([tf.rank(arr), tf.rank(a)]))
   print(arr.get_shape(), ", ", a.get_shape())     


# for tf.rank()    
[4, 3]

# for tf.get_shape()
Output: (10, 32, 32, 128) , (3, 3, 1)
于 2017-04-08T16:57:10.967 回答
1

方法 tf.shape 是一种 TensorFlow 静态方法。但是,Tensor 类也有 get_shape 方法。看

https://www.tensorflow.org/api_docs/python/tf/Tensor#get_shape

于 2017-11-09T19:38:49.130 回答