我正在尝试根据使用 feed_dict 输入的输入的维度来 tf.split 一个张量(每个批次的输入变化的维度)。目前我一直收到一个错误,说张量不能用“维度”分割。有没有办法获取维度的值并使用它进行拆分?
谢谢!
input_d = tf.placeholder(tf.int32, [None, None], name="input_d")
# toy feed dict
feed = {
input_d: [[20,30,40,50,60],[2,3,4,5,-1]] # document
}
W_embeddings = tf.get_variable(shape=[vocab_size, embedding_dim], \
initializer=tf.random_uniform_initializer(-0.01, 0.01),\
name="W_embeddings")
document_embedding = tf.gather(W_embeddings, input_d)
timesteps_d = document_embedding.get_shape()[1]
doc_input = tf.split(1, timesteps_d, document_embedding)