2

我使用大小为 4000x300(4000 个质心,每个质心有 300 个特征)的 k-means 创建了一个码本。使用密码本,然后我想标记一个输入向量(用于稍后进行分箱)。输入向量的大小为 Nx300,其中 N 是我收到的输入实例的总数。

为了计算标签,我为每个输入向量计算最近的质心。为此,我将每个输入向量与所有质心进行比较,并选择距离最小的质心。那么标签就是那个质心的索引。

我当前的 Matlab 代码如下所示:

function labels = assign_labels(centroids, X)
labels = zeros(size(X, 1), 1);

% for each X, calculate the distance from each centroid
for i = 1:size(X, 1)
    % distance of X_i from all j centroids is: sum((X_i - centroid_j)^2)
    % note: we leave off the sqrt as an optimization
    distances = sum(bsxfun(@minus, centroids, X(i, :)) .^ 2, 2);
    [value, label] = min(distances);
    labels(i) = label;
end     

但是,这段代码仍然相当慢(出于我的目的),我希望有一种方法可以进一步优化代码。

一个明显的问题是有一个 for 循环,它是 Matlab 良好性能的祸根。我一直试图想出一种方法来摆脱它,但没有运气(我研究过将 arrayfun 与 bsxfun 结合使用,但还没有让它起作用)。或者,如果有人知道任何其他加快速度的方法,我将不胜感激。

更新

在做了一些搜索之后,我找不到使用 Matlab 的好解决方案,所以我决定查看 Python 的 scikits.learn 包中用于 'euclidean_distance' (缩短)的内容:

 XX = sum(X * X, axis=1)[:, newaxis]
 YY = Y.copy()
 YY **= 2
 YY = sum(YY, axis=1)[newaxis, :]
 distances = XX + YY
 distances -= 2 * dot(X, Y.T)
 distances = maximum(distances, 0)

它使用欧几里得距离 ((xy)^2 -> x^2 + y^2 - 2xy) 的二项式形式,据我所知,它通常运行得更快。我完全未经测试的 Matlab 翻译是:

 XX = sum(data .* data, 2);
 YY = sum(center .^ 2, 2);
 [val, ~] = max(XX + YY - 2*data*center');
4

4 回答 4

4

使用以下函数计算您的距离。你应该看到一个数量级的加速

两个矩阵 A 和 B 将列作为维度,将行作为每个点。A 是您的质心矩阵。B 是您的数据点矩阵。

function D=getSim(A,B)
    Qa=repmat(dot(A,A,2),1,size(B,1));
    Qb=repmat(dot(B,B,2),1,size(A,1));
    D=Qa+Qb'-2*A*B';
于 2011-04-24T11:03:16.487 回答
1

对于真正的矩阵实现,您可以考虑尝试以下方式:

  P2 = kron(centroids, ones(size(X,1),1));
  Q2 = kron(ones(size(centroids,1),1), X);

  distances = reshape(sum((Q2-P2).^2,2), size(X,1), size(centroids,1));

注意 这假设数据组织为 [x1 y1 ...; x2 y2 ...;...]

于 2011-04-22T14:51:40.200 回答
1

您可以通过转换为单元格并使用以下方法对其进行矢量化cellfun

[nRows,nCols]=size(X);
XCell=num2cell(X,2);
dist=reshape(cell2mat(cellfun(@(x)(sum(bsxfun(@minus,centroids,x).^2,2)),XCell,'UniformOutput',false)),nRows,nRows);
[~,labels]=min(dist);

解释:

  • X我们在第二行将每一行分配给它自己的单元格
  • 这部分@(x)(sum(bsxfun(@minus,centroids,x).^2,2))是一个匿名函数,与您的distances=...行相同,使用cell2mat,我们将其应用于X.
  • 然后标签是沿每列的最小行的索引。
于 2011-04-20T23:16:22.630 回答
1

您可以使用比蛮力更有效的最近邻搜索算法。最流行的方法是 Kd-Tree。O(log(n)) 平均查询时间而不是 O(n) 蛮力复杂度。关于 Kd-Trees 的 Maltab 实现,你可以看看这里

于 2013-11-07T13:28:45.647 回答