4

我有一个embedded_chars数组,使用以下代码创建:

self.input_x = tf.placeholder(tf.int32, [None, sequence_length], name="input_x")

W = tf.Variable( 
    tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
    name="W"
    )
self.embedded_chars = tf.nn.embedding_lookup(W, self.input_x) 

input_x如果我只有embedded_charsand ,我想得到数组W

我怎么才能得到它?

谢谢!

4

1 回答 1

5

W您可以使用和中的嵌入向量之间的余弦距离embedded_chars

# assume embedded_chars.shape == (batch_size, embedding_size)
emb_distances = tf.matmul( # shape == (vocab_size, batch_size)
    tf.nn.l2_normalize(W, dim=1),
    tf.nn.l2_normalize(embedded_chars, dim=1),
    transpose_b=True)
token_ids = tf.argmax(emb_distances, axis=0) # shape == (batch_size)

emb_distances是 L2 归一化矩阵W和的点积,它与 中的所有向量与 中的所有向量transpose(embedded_chars)之间的余弦距离相同。argmax 只是简单地为我们提供了每列中最大值的索引。Wembedded_charsemb_distances

@Yao Zhang:如果所有嵌入W都不同,应该是不同的,那么这将为您提供正确的结果:余弦距离始终在 [-1, 1] 和 cos(vector_a, vector_a) == 1 之间。

请注意,通常您不需要进行从嵌入到标记索引的这种转换:通常您可以直接将作为第二个参数传递的张量的值传递给tf.nn.embedding_embedding_lookup,这是标记索引的张量。

于 2017-10-04T14:57:52.293 回答