9

我需要在matlab中计算2个矩阵之间的欧几里得距离。目前我正在使用 bsxfun 并计算距离如下(我附上了一段代码):

for i=1:4754
test_data=fea_test(i,:);
d=sqrt(sum(bsxfun(@minus, test_data, fea_train).^2, 2));
end

fea_test 的大小是 4754x1024 和 fea_train 是 6800x1024 ,使用他的 for 循环导致 for 的执行需要大约 12 分钟,我认为这太高了。有没有办法更快地计算两个矩阵之间的欧几里得距离?

有人告诉我,通过删除不必要的 for 循环,我可以减少执行时间。我也知道 pdist2 可以帮助减少计算时间,但由于我使用的是 matlab 7 版,所以我没有 pdist2 函数。升级不是一种选择。

任何帮助。

问候,

巴维亚

4

3 回答 3

12

这是用于计算欧几里得距离的矢量化实现,它比你所拥有的要快得多(甚至比我机器上的PDIST2快得多):

D = sqrt( bsxfun(@plus,sum(A.^2,2),sum(B.^2,2)') - 2*(A*B') );

它基于以下事实:||u-v||^2 = ||u||^2 + ||v||^2 - 2*u.v


考虑下面两种方法之间的粗略比较:

A = rand(4754,1024);
B = rand(6800,1024);

tic
D = pdist2(A,B,'euclidean');
toc

tic
DD = sqrt( bsxfun(@plus,sum(A.^2,2),sum(B.^2,2)') - 2*(A*B') );
toc

在我运行 R2011b 的 WinXP 笔记本电脑上,我们可以看到 10 倍的时间改进:

Elapsed time is 70.939146 seconds.        %# PDIST2
Elapsed time is 7.879438 seconds.         %# vectorized solution

您应该知道,它不会给出与 PDIST2完全相同的结果,直到最小精度。通过比较结果,您会看到细微的差异(通常接近eps浮点相对精度):

>> max( abs(D(:)-DD(:)) )
ans =
  1.0658e-013

附带说明一下,我收集了大约 10 种不同的实现(有些只是彼此之间的小变化),用于这种距离计算,并一直在比较它们。与其他矢量化解决方案相比,您会惊讶于简单循环的速度有多快(感谢 JIT)......

于 2011-10-14T22:41:08.520 回答
2

fea_test您可以通过重复6800 次和fea_train4754 次的行来完全矢量化计算,如下所示:

rA = size(fea_test,1);
rB = size(fea_train,1);

[I,J]=ndgrid(1:rA,1:rB);

d = zeros(rA,rB);

d(:) = sqrt(sum(fea_test(J(:),:)-fea_train(I(:),:)).^2,2));

但是,这将导致大小为 6800x4754x1024 的中间数组(* 8 字节为双精度),这将占用约 250GB 的 RAM。因此,完全矢量化将不起作用。

但是,您可以通过预分配来减少距离计算的时间,并且在必要之前不计算平方根:

rA = size(fea_test,1);
rB = size(fea_train,1);
d = zeros(rA,rB);

for i = 1:rA
    test_data=fea_test(i,:);
    d(i,:)=sum( (test_data(ones(nB,1),:) -  fea_train).^2, 2))';
end

d = sqrt(d);
于 2011-10-08T13:02:13.877 回答
0

试试这个矢量化版本,它应该非常有效。编辑:刚刚注意到我的答案与@Amro 的相似。

function K = calculateEuclideanDist(P,Q)
% Vectorized method to compute pairwise Euclidean distance
% Returns K(i,j) = sqrt((P(i,:) - Q(j,:))'*(P(i,:) - Q(j,:)))

[nP, d] = size(P);
[nQ, d] = size(Q);

pmag = sum(P .* P, 2);
qmag = sum(Q .* Q, 2);

K = sqrt(ones(nP,1)*qmag' + pmag*ones(1,nQ) - 2*P*Q');

end
于 2012-06-29T16:08:22.293 回答