1

我想在我的 IR 项目中使用余弦相似度,但是因为向量的大小很大并且必须多次乘以浮点数,所以需要很长时间。

有什么方法可以更快地计算余弦相似度?

这是我的代码:

private double diffrence(HashMap<Integer, Float> hashMap,
 HashMap<Integer, Float> hashMap2 ) {
    Integer[] keys = new Integer[hashMap.size()];
    hashMap.keySet().toArray(keys);

     float ans = 0;

    for (int i = 0; i < keys.length; i++) {
        if (hashMap2.containsKey(keys[i])) {
             ans += hashMap.get(keys[i]) * hashMap2.get(keys[i]);

        }
    }

     float hashLength = 0;
    for (int i = 0; i < keys.length; i++) {
         hashLength += (hashMap.get(keys[i]) * hashMap.get(keys[i]));
    }
     hashLength = (float) Math.sqrt(hashLength);

    Integer[] keys2 = new Integer[hashMap2.size()];
    hashMap2.keySet().toArray(keys2);

     float hash2Length = 0;
    for (int i = 0; i < keys2.length; i++) {

         hash2Length += hashMap2.get(keys2[i]) * hashMap2.get(keys2[i]);

    }
     hash2Length = (float) Math.sqrt(hash2Length);

    return (float) (ans /(hash2Length*hashLength));
}
4

5 回答 5

9

通常在 IR 中,一个向量的非零元素比另一个向量少得多(通常查询向量是更稀疏的向量,但即使对于文档向量也是如此)。您可以通过循环遍历稀疏向量的键来节省时间,即较小的哈希映射,在较大的哈希映射中查找它们。

至于 pkacprzak 对查找表的建议和您的内存不足:意识到可以在余弦相似度计算之前完成标准化。对于每个向量,在存储它之前,计算它的范数并将每个元素除以它。然后,你可以计算一个点积并得到一个余弦相似度。

即,余弦相似度通常定义为

x·y / (||x|| × ||y||)

但这等于

(x / ||x||) · (y / ||y||)

其中/是元素划分。如果您每个都替换xx / ||x||,那么您只需要计算x·y.

如果你结合这两个建议,你会得到一个余弦相似度算法,它只对两个输入中较小的一个进行一个循环。

通过使用更智能的稀疏向量结构可以进行进一步的改进;哈希表在查找和迭代中都有很多开销。

于 2013-06-26T19:52:14.567 回答
1

Usually there are too many vectors to precompute the cosine similarity of each pair, but you could precompute the length of every vector and store it using a lookup table. This reduces a constant factor in computing the cosine similarity of two vectors - actually it saves a significant amount of time, because of a lot of floating point operations.

I'm assuming that you are not wasting memory by storing zeros in the vector.

于 2013-06-26T19:45:40.830 回答
1

除了按照其他人的建议对向量进行预规范化并假设您的向量列表没有改变之外,将它们转换为数组对一次(在相似性函数之外)并按键索引对它们进行排序,例如:

Integer[] keys = new Integer[hashMap.size()];
Float values[] = new Float[keys.size()];
int i = 0;
float norm = ...;    
for (Map.Entry<Integer, Float> entry : new TreeMap<Integer, Float>(hashMap).entrySet())
{
   keys[i] = entry.getKey();
   values[i++] = entry.getValue() / norm;
}

然后进行实际的相似度计算(假设您然后通过keys1, values, keys2,values2而不是两个HashMaps),您的最内层循环减少为:

float ans = 0;
int i,j = 0;
while (i < keys1.length && j < keys2.length)
{
  if (keys1[i] < keys2[j])
    ++i;
  else if (keys1[i] > keys2[j])
    ++j;
  else
    // we have the same key in 1 and 2
    ans += values1[i] * values2[j];
}

您甚至可以考虑将所有keysvalues所有向量连续存储在一个大数组中,intfloat在第一个位置保留另一个带有索引的数组:

int sumOfAllVectorLengths = ...;
int allKeys[] = new int[sumOfAllVectorLengths];
float allValues[] = new float[sumOfAllVectorLengths];
int firstPos = new int[numberOfVectors + 1]; 
firstPos[numberOfVectors] = sumOfAllVectorLengths;

int nextFirstPos = 0;
int index = 0;

for (HashMap<Integer, Float> vector : allVectors)
{
   firstPos[index] = nextFirstPos;

   float norm = ...;    
   for (Map.Entry<Integer, Float> entry : new TreeMap<Integer, Float>(hashMap).entrySet())
   {
      keys[nextFirstPos] = entry.getKey();
      values[nextFirstPos++] = entry.getValue() / norm;
   }

   ++index; 
}

然后将数组和向量的索引传递给比较函数。

于 2013-06-26T20:21:55.123 回答
0

您可以查看项目 simbase https://github.com/guokr/simbase,它是一个向量相似度 nosql 数据库。

Simbase 使用以下概念:

  • 向量集:一组向量
  • 基:向量的基,一个向量集中的向量具有相同的基
  • 推荐:两个具有相同基的向量集之间的单向二元关系

写入操作在每个基础上在单个线程中处理,并且需要在任意两个向量之间进行比较,因此写入操作的比例为 O(n)。

我们在 i7-cpu Macbook 上对密集向量进行了非最终性能测试,它可以在 0.14 秒内轻松处理 100k 1k 维向量,每次写入操作;如果线性比例可以保持,这意味着 Simbase 可以在 1 秒内处理 700k 个密集向量,每次写入操作。

于 2014-06-12T15:18:31.097 回答
-2

我可以清楚地看到至少一个地方,你只是在浪费 CPU 周期:

for (int i = 0; i < keys.length; i++) {
    if (hashMap2.containsKey(keys[i])) {
         ans += hashMap.get(keys[i]) * hashMap2.get(keys[i]);
    }
}

float hashLength = 0;
for (int i = 0; i < keys.length; i++) {
     hashLength += (hashMap.get(keys[i]) * hashMap.get(keys[i]));
}

在这里,您在相同的 2 个 hashMap 上有 2 个相同边界的循环。你为什么不在一个周期内完成它:

float hashLength = 0;
int hm = 0;
for (int i = 0; i < keys.length; i++) {
    hm = hashMap.get(keys[i])*hashMap2.get(keys[i]);
    hashLength += hm;
    if (hashMap2.containsKey(keys[i])) {
         ans += hm;
    }
}

顺便问一下,使用hashMap有什么特殊原因吗?或者你可以用一些更简单的数组来做?

于 2013-06-26T19:28:25.723 回答