1

我正在尝试根据使用 feed_dict 输入的输入的维度来 tf.split 一个张量(每个批次的输入变化的维度)。目前我一直收到一个错误,说张量不能用“维度”分割。有没有办法获取维度的值并使用它进行拆分?

谢谢!

input_d = tf.placeholder(tf.int32, [None, None], name="input_d")

# toy feed dict
feed = {
    input_d: [[20,30,40,50,60],[2,3,4,5,-1]] # document
}

W_embeddings = tf.get_variable(shape=[vocab_size, embedding_dim], \
                  initializer=tf.random_uniform_initializer(-0.01, 0.01),\
                  name="W_embeddings")  
document_embedding = tf.gather(W_embeddings, input_d)

timesteps_d = document_embedding.get_shape()[1]
doc_input = tf.split(1, timesteps_d, document_embedding)
4

1 回答 1

0

tf.split接受一个 Python 整数作为num_split参数。但是,document_embedding.get_shape()返回 aTensorShapedocument_embedding.get_shape()[1]给出一个Dimension实例,因此您会收到一条错误消息“无法使用 Dimension 拆分”。

试试看timestep_ds = document_embedding.get_shape().as_list()[1],这个语句应该给你一个 python 整数。

以下是tf.splittf.Tensor.get_shape的一些相关文档

于 2016-11-09T04:51:52.833 回答