0

我为冗长的标题道歉。请考虑以下代码,它可能会出现在k-means集群中。

rng(1)
num_samples = 10;
samples = randi(100, num_samples, 3);
cluster_centroids = randi(16,3);
cluster_indices = zeros(num_samples,1);

for index = 1:num_samples
    distances = sqrt(sum((samples(index) - cluster_centroids).^2, 2));
    cluster = find(distances == min(distances), 1)
    cluster_indices(index) = cluster;
end

是否有某种方法可以对其进行矢量化并删除 for 循环,以便我们有效地处理所有样本(它们是三个整数的元组)?

4

1 回答 1

2

我在底部列出的现有代码存在几个问题,但首先是矢量化答案。答案的核心类似于我过去的解决方案。我对您的变量进行了一些修改,并添加了更多解释:

[nPoints,nDims] = size(samples);
k = 3; % ? size(cluster_centroids,1)

% Calculate all high-dimensional distances at once
% (NxDx1 - 1xDxK => NxDxK)
kdiffs = bsxfun(@minus,samples,permute(cluster_centroids,[3 2 1]));
distances = sum(kdiffs.^2,2); % no need to do sqrt
distances = squeeze(distances); % Nx1xK => NxK

是矢量化的distances重要值。其余的相当琐碎:

% Find closest cluster center for each point
[~,cluster_indices] = min(distances,[],2); % Nx1

然后,您将需要为以下迭代更新集群中心:

cluster_centroids_new = zeros(k,nDims);
for i=1:k,
    indk = cluster_indices==i;
    clustersizes(i) = nnz(indk);
    cluster_centroids_new(i,:) = mean(samples(indk,:))';
end

要对此代码添加一点讨论,请注意cluster_centroids_new每个集群都有一行,但如果集群没有成员,那么该行将是NaNs。

有问题的代码问题:

  1. cluster_indices被初始化为矩阵而不是向量。修复:(cluster_indices = zeros(num_samples,1);如 horchler 所述)。
  2. 您在计算中没有samples正确索引。distances更改samples(index)samples(index,:)提取样本。
  3. 每个 ND 样本之间的距离cluster_centroids需要sum((repmat(samples(index,:),size(cluster_centroids,1),1) - cluster_centroids).^2, 2)计算一个样本到所有k=size(cluster_centroids,1)簇的距离。
  4. 不一定是错误,但请确保您的意思是cluster_centroids = randi(16,3);它给出了一个 3x3 矩阵,数字从 1 到 16。
于 2013-11-12T00:49:14.690 回答