我目前有要使用嵌入的 onehot 编码。但是,当我打电话时
embed=tf.nn.embedding_lookup(embeddings, train_data)
print(embed.get_shape())
嵌入数据形状(11、32、729、128)
这个形状应该是 (11, 32, 128) 但它给了我错误的尺寸,因为 train_data 是 onehot 编码的。
train_data2=tf.matmul(train_data,tf.range(729))
给我错误:
ValueError: Shape must be rank 2 but is rank 3
请帮帮我!谢谢。