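I have implemented the self-organizing map (SOM) algorithm in MATLAB. Assume each data point is represented in 2-D space. The problem is that I want to visualize how each data point moves during the training phase: I want to see how the points move and eventually form clusters as the algorithm progresses, say at fixed time intervals. I believe this can be done with a simulation in MATLAB, but I don't know how to incorporate my MATLAB code into the visualization.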
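In a SOM it is the prototype (weight) vectors that move during training, while the data points themselves stay fixed. A common MATLAB pattern for watching them move is to draw once, keep the graphics handle, and update its `XData`/`YData` inside the training loop followed by `drawnow`. The sketch below is not the asker's SOM: it assumes a 1-D chain of 10 units, a Gaussian neighborhood, and linearly decaying learning rate and neighborhood width, all chosen arbitrarily for illustration.

```matlab
% Minimal 1-D SOM (a chain of prototypes) trained on 2-D data and
% redrawn every plotEvery steps so the prototypes can be watched moving.
% Topology and schedules are illustrative assumptions.
rng(0);                                            % reproducible example
X = [randn(200,2)*0.3 + 1; randn(200,2)*0.3 - 1];  % two 2-D blobs
nUnits = 10;                                       % number of SOM units
W = rand(nUnits,2)*2 - 1;                          % one prototype per row
nSteps = 2000;
plotEvery = 100;                                   % redraw interval (steps)

figure('Name','SOM training');
plot(X(:,1), X(:,2), '.', 'Color', [.7 .7 .7]); hold on;
hW = plot(W(:,1), W(:,2), 'o-', 'LineWidth', 1.5); % handle kept for updates
hT = title('step 0');

for t = 1:nSteps
    x = X(randi(size(X,1)),:);                        % random sample
    [~, bmu] = min(sum(bsxfun(@minus, W, x).^2, 2));  % best-matching unit
    frac = t/nSteps;
    lr = 0.5*(1-frac) + 0.01;                         % decaying learning rate
    sigma = max(nUnits/2*(1-frac), 0.5);              % decaying neighborhood width
    nbh = exp(-((1:nUnits)' - bmu).^2 / (2*sigma^2)); % Gaussian weight per unit
    W = W + lr * bsxfun(@times, nbh, bsxfun(@minus, x, W));
    if mod(t, plotEvery) == 0
        set(hW, 'XData', W(:,1), 'YData', W(:,2));    % move the existing line
        set(hT, 'String', sprintf('step %d', t));
        drawnow;                                      % flush the update to screen
    end
end
```

Updating the existing line object instead of calling `plot` again keeps each redraw cheap; the same loop could call `getframe` after `drawnow` to record frames.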

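I put together a code example that visualizes clustered multidimensional data using every possible 2-D projection of the data. This is probably not the best way to visualize it (dedicated techniques exist for this; SOM itself can be used for exactly that purpose), especially for a high number of dimensions, but as long as the number of possible projections, n(n-1)/2 for n dimensions, is not too large, it is a fairly good visualization tool.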
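The projection grid draws one scatter plot per unordered pair of dimensions, the same pairs visited by the `curColumn`/`curLine` loops in `kmeans_test`. `nchoosek` enumerates them directly; a quick count for the 5-D case:

```matlab
% Enumerate every pair of dimensions that gets its own 2-D scatter plot.
nDim = 5;
pairs = nchoosek(1:nDim, 2);     % each row is [xDim yDim], lexicographic order
nProjections = size(pairs, 1);   % equals nDim*(nDim-1)/2
fprintf('%d dimensions -> %d projections\n', nDim, nProjections);
for p = 1:nProjections
    fprintf('  Dim %d vs Dim %d\n', pairs(p,1), pairs(p,2));
end
```

The count grows quadratically with the number of dimensions (45 pairs for 10 dimensions), which is why the grid becomes crowded quickly.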


Clustering algorithm

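Since I needed access to the code so that I could save the cluster means and cluster labels at every iteration, I used Mo Chen's fast k-means algorithm from the FEX, but I had to adapt it to get that access. The modified code is as follows: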

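function [label,m] = litekmeans(X, k)
% Perform k-means clustering, keeping the state of every iteration.
%   X: d x n data matrix
%   k: number of seeds
%   label: cell array; label{i} holds the 1 x n sample labels at iteration i
%   m: cell array; m{i} holds the d x k_i cluster centers at iteration i
% Written by Michael Chen (sth4nth@gmail.com).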
n = size(X,2);
last = 0;
iter = 1;
label{iter} = ceil(k*rand(1,n));  % random initialization
checkLabel = label{iter};
m = {};
while any(checkLabel ~= last)
    [u,~,checkLabel] = unique(checkLabel);   % remove empty clusters
    k = length(u);
    E = sparse(1:n,checkLabel,1,n,k,n);  % transform label into indicator matrix
    curM = X*(E*spdiags(1./sum(E,1)',0,k,k));    % compute m of each cluster
    m{iter} = curM;
    last = checkLabel';
    [~,checkLabel] = max(bsxfun(@minus,curM'*X,dot(curM,curM,1)'/2),[],1); % assign samples to the nearest centers
    iter = iter + 1;
    label{iter} = checkLabel;
end
% Get last clusters centers
m{iter} = curM;
% If to remove empty clusters:
%for k=1:iter
%  [~,~,label{k}] = unique(label{k});
%end
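The assignment line in `litekmeans` relies on expanding the squared distance: ||x − m_j||² = ||x||² − 2 m_jᵀx + ||m_j||². Because ||x||² is the same for every center, the nearest center is the one that maximizes m_jᵀx − ||m_j||²/2, which is what `bsxfun(@minus,curM'*X,dot(curM,curM,1)'/2)` followed by `max` computes. The small check below, on made-up points, compares that rule with an explicit distance computation.

```matlab
% Check that the litekmeans assignment rule picks the nearest center.
X = [0 1 4 5 9; 0 0 1 1 3];        % d x n samples (2-D, 5 points)
M = [0 5; 0 1];                    % d x k centers (2 centers)

% Rule used in litekmeans: argmax_j ( m_j'*x - ||m_j||^2/2 )
score = bsxfun(@minus, M'*X, dot(M,M,1)'/2);   % k x n
[~, labelTrick] = max(score, [], 1);

% Direct rule: argmin_j ||x - m_j||^2
d2 = zeros(size(M,2), size(X,2));
for j = 1:size(M,2)
    d2(j,:) = sum(bsxfun(@minus, X, M(:,j)).^2, 1);
end
[~, labelDirect] = min(d2, [], 1);

disp([labelTrick; labelDirect]);
```

The trick avoids forming the full k x n distance matrix explicitly and needs only one matrix product per iteration.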

GIF creation

I also used @Amro's MATLAB video tutorial to create the GIF.
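The answer's code preallocates a 4-D indexed array and writes it with a single `imwrite` call. An alternative, shown here as a sketch rather than the answer's method, is to append one frame at a time with `imwrite`'s `'WriteMode','append'`. To keep the example runnable without a display, the frames are synthetic indexed images (a dot moving right) instead of `getframe` captures; the output file name is a hypothetical placeholder.

```matlab
% Write an animated GIF one frame at a time (synthetic frames, no figure needed).
gifName = fullfile(tempdir, 'movingDot.gif');   % hypothetical output path
nFrames = 6;
map = [1 1 1; 0.85 0.1 0.1];                    % colormap: white background, red dot
for k = 1:nFrames
    frame = zeros(40, 60, 'uint8');             % uint8 index 0 -> first map row
    col = 5 + (k-1)*9;                          % dot moves right each frame
    frame(16:24, col:col+8) = 1;                % uint8 index 1 -> second map row
    if k == 1
        % The first write creates the file and sets the loop behavior.
        imwrite(frame, map, gifName, 'gif', 'LoopCount', Inf, 'DelayTime', 0.5);
    else
        imwrite(frame, map, gifName, 'gif', 'WriteMode', 'append', 'DelayTime', 0.5);
    end
end
info = imfinfo(gifName);                        % one struct per GIF frame
fprintf('%s has %d frames\n', gifName, numel(info));
```

With figure captures, each `frame`/`map` pair would come from `rgb2ind(getframe(figH).cdata, map, 'nodither')` as in the answer. Appending avoids holding all frames in memory, at the cost of reopening the file for every frame.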

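Distinguishable colors

I used this great FEX submission by Tim Holy (distinguishable_colors) to make the cluster colors easier to tell apart.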
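The code colors each sample by indexing a k x 3 color matrix with the sample labels (`colorMatrix(label{...},:)`). If the FEX function is not installed, a built-in colormap such as `hsv(k)` can stand in; this substitution is an assumption on my part, and its hues are less distinct than `distinguishable_colors` for large k. The labels below are made up.

```matlab
% Map per-sample cluster labels to RGB rows, as kmeans_test does with colorMatrix.
k = 4;
colorMatrix = hsv(k);                   % built-in stand-in for distinguishable_colors(k)
labels = [1 3 3 2 4 1];                 % hypothetical cluster label of each sample
pointColors = colorMatrix(labels, :);   % one RGB row per sample
disp(pointColors);
```

The resulting n x 3 matrix can be passed as the color argument of `scatter`, which is how the answer's code feeds `scatterColors` to the plots.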

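Resulting code

This is the code I ended up with. I ran into some trouble because the number of clusters changes from one iteration to the next, which made the scatter plot update remove all the cluster centers without raising any error. Since I hadn't noticed that, I tried to work around the scatter function with every obscure method I could find on the web (by the way, I found a very nice scatter plot alternative here), but luckily I figured out today what was going on. Here is the code I wrote for it; feel free to use and adapt it, but if you do, please keep my reference.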
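One way to sidestep the size mismatch described above (a suggestion, not what the answer's code does, since it relies on `XDataSource`/`refreshdata`) is to keep the scatter handle and set `XData`, `YData` and `CData` in a single `set` call, so the three sizes change together when an empty cluster is dropped. A minimal sketch with made-up centers:

```matlab
% Update a scatter object when the number of points changes between iterations.
fig = figure('Visible', 'off');
ax = axes('Parent', fig);
colors = hsv(3);
centers = [0 1 2; 0 1 0];                       % 2 x k centers, k = 3
h = scatter(ax, centers(1,:), centers(2,:), 100, colors, '^', 'filled');

% Next iteration: one cluster became empty and was removed, so k = 2.
centers = [0.2 1.8; 0.1 0.2];
k = size(centers, 2);
% Set XData, YData and CData in one call so their sizes stay consistent.
set(h, 'XData', centers(1,:), 'YData', centers(2,:), 'CData', colors(1:k,:));
drawnow;
```

This also removes the need for the `eval`/`genvarname` workspace variables that the DataSource approach requires.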

function varargout=kmeans_test(data,nClusters,plotOpts,dimLabels,...
  bigXDim,bigYDim,gifName)
%
% [label,m,figH,handles]=kmeans_test(data,nClusters,plotOpts,...
%   dimLabels,bigXDim,bigYDim,gifName)
% Demonstrate kmeans algorithm iterative progress. Inputs are:
%
% -> data (5 x 140 synthetic set with 7 clusters): the data to use.
%
% -> nClusters (7): number of clusters to use.
%
% -> plotOpts: struct holding the following fields:
%
%   o leftBase: the percentage distance from the left
%
%   o rightBase: the percentage distance from the right
%
%   o bottomBase: the percentage distance from the bottom
%
%   o topBase: the percentage distance from the top
%
%   o FontSize: FontSize for axes labels.
%
%   o widthUsableArea: Total width occupied by axes
%
%   o heigthUsableArea: Total heigth occupied by axes
%
% -> bigXDim (1): the big subplot x dimension
%
% -> bigYDim (2): the big subplot y dimension
%
% -> dimLabels: If you want to specify dimensions labels
%
% -> gifName: gif file name to save
%
% Outputs are:
% 
% -> label: Sample cluster center number for each iteration
%
% -> m: cluster center mean for each iteration
%
% -> figH: figure handle
%
% -> handles: axes handles
%

%
% - Creation Date: Fri, 13 Sep 2013 
% - Last Modified: Mon, 16 Sep 2013 
% - Author(s): 
%   - W.S.Freund <wsfreund_at_gmail_dot_com> 

%
% TODO List (?):
%
%  - Use input parser 
%  - Adapt it to be able to cluster any algorithm function.
%  - Use arrows indicating cluster centers movement before moving them.
%  - Drag and drop small axes to big axes.
%

% Pre-start
if nargin < 7
  gifName = 'kmeansClusterization.gif';
  if nargin < 6
    bigYDim = 2;
    if nargin < 5
      bigXDim = 1;
      if nargin < 4
        nDim = size(data,1);
        maxDigits = numel(num2str(nDim));
        dimLabels = mat2cell(sprintf(['Dim %0' num2str(maxDigits) 'd'],...
          1:nDim),1,zeros(1,nDim)+4+maxDigits);
        if nargin < 3
          plotOpts = struct('leftBase',.05,'rightBase',.02,...
            'bottomBase',.05,'topBase',.02,'FontSize',10,...
            'widthUsableArea',.87,'heigthUsableArea',.87);
          if nargin < 2
            nClusters = 7;
            if nargin < 1
              center1 = [1; 0; 0; 0; 0];
              center2 = [0; 1; 0; 0; 0];
              center3 = [0; 0; 1; 0; 0];
              center4 = [0; 0; 0; 1; 0];
              center5 = [0; 0; 0; 0; 1];
              center6 = [0; 0; 0; 0; 1.5];
              center7 = [0; 0; 0; 1.5; 1];
              data = [...
                      bsxfun(@plus,center1,.5*rand(5,20)) ...
                      bsxfun(@plus,center2,.5*rand(5,20)) ...
                      bsxfun(@plus,center3,.5*rand(5,20)) ...
                      bsxfun(@plus,center4,.5*rand(5,20)) ...
                      bsxfun(@plus,center5,.5*rand(5,20)) ...
                      bsxfun(@plus,center6,.2*rand(5,20)) ...
                      bsxfun(@plus,center7,.2*rand(5,20)) ...
                     ];
            end
          end
        end
      end
    end
  end
end

% NOTE of advice: when using scatter, refreshdata does not seem to
% check whether the XData, YData and CData sources still have matching
% sizes. Because of this I wasted a lot of time debugging the problem
% and trying many workarounds, since my cluster centers would
% disappear for no apparent reason.

% Draw axes:
nDim = size(data,1);

figH = figure;
set(figH,'Units', 'normalized', 'Position',...
  [0, 0, 1, 1],'Color','w','Name',...
  'k-means example','NumberTitle','Off',...
  'MenuBar','none','Toolbar','figure',...
  'Renderer','zbuffer');

% Create distinguishable colors matrix:
colorMatrix = distinguishable_colors(nClusters);

% Create the axes and store their handles in a matrix laid out roughly
% as the axes will be positioned:
[handles,horSpace,vertSpace] = ...
  createAxesGrid(nDim,nDim,plotOpts,dimLabels);

% Add main axes
bigSubSize = ceil(nDim/2);
bigSubVec(bigSubSize^2) = 0;
for k = 0:nDim-bigSubSize
  bigSubVec(k*bigSubSize+1:(k+1)*bigSubSize) = ...
    ... %(nDim-bigSubSize+k)*nDim+1:(nDim-bigSubSize+k)*nDim+(nDim-bigSubSize+1);
    bigSubSize+nDim*k:nDim*(k+1);
end

handles(bigSubSize,bigSubSize) = subplot(nDim,nDim,bigSubVec,...
  'FontSize',plotOpts.FontSize,'box','on'); 
bigSubplotH = handles(bigSubSize,bigSubSize);
horSpace(bigSubSize,bigSubSize) = bigSubSize;
vertSpace(bigSubSize,bigSubSize) = bigSubSize;
set(bigSubplotH,'NextPlot','add',...
  'FontSize',plotOpts.FontSize,'box','on',...
  'XAxisLocation','top','YAxisLocation','right');

% Squeeze axes through space to optimize space usage and improve
% visualization capability:
[leftPos,botPos,subplotWidth,subplotHeight]=setCustomPlotArea(...
  handles,plotOpts,horSpace,vertSpace);

pColorAxes = axes('Position',[leftPos(end) botPos(end) ...
  subplotWidth subplotHeight],'Parent',figH);
pcolor(pColorAxes,[1:nClusters+1;1:nClusters+1])
colormap(pColorAxes,colorMatrix);
% pcolor draws face k between x = k and x = k+1 and sets tight axis
% limits [1 nClusters+1], so the tick for cluster k sits at k+.5 (a
% tick at .5 would fall outside the limits and shift the labels):
set(pColorAxes,'XTick',1.5:1:nClusters+.5);
set(pColorAxes,'YTick',[]);
% Label each tick with its cluster number:
set(pColorAxes,'XTickLabel',1:nClusters);
xlabel(pColorAxes,'Clusters Colors','FontSize',plotOpts.FontSize);

% Run k-means and collect the labels and centers of every iteration:
[label,m]=litekmeans(data,nClusters);

nIters = numel(m)-1;

scatterColors = colorMatrix(label{1},:);

annH = annotation('textbox',[leftPos(1),botPos(1) subplotWidth ...
  subplotHeight],'String',sprintf('Start Conditions'),'EdgeColor',...
  'none','FontSize',18);

% Creates dimData_%d variables for first iteration:
for curDim=1:nDim
  curDimVarName = genvarname(sprintf('dimData_%d',curDim));
  eval([curDimVarName,'= m{1}(curDim,:);']);
end

%   clusterColors will hold the colors for the total number of clusters
% on each iteration:
clusterColors = colorMatrix;

% Draw cluster information for first iteration:
for curColumn=1:nDim
  for curLine=curColumn+1:nDim
    % Big subplot data:
    if curColumn == bigXDim && curLine == bigYDim
      curAxes = handles(bigSubSize,bigSubSize);
      curScatter = scatter(curAxes,data(curColumn,:),...
        data(curLine,:),16,'filled');
      set(curScatter,'CDataSource','scatterColors');
      % Draw cluster centers 
      curColumnVarName = genvarname(sprintf('dimData_%d',curColumn));
      curLineVarName = genvarname(sprintf('dimData_%d',curLine));
      eval(['curScatter=scatter(curAxes,' curColumnVarName ',' ... 
        curLineVarName ',100,colorMatrix,''^'',''filled'');']);
      set(curScatter,'XDataSource',curColumnVarName,'YDataSource',...
        curLineVarName,'CDataSource','clusterColors')
    end
    % Small subplots data:
    curAxes = handles(curLine,curColumn);
    % Draw data:
    curScatter = scatter(curAxes,data(curColumn,:),...
      data(curLine,:),16,'filled');
    set(curScatter,'CDataSource','scatterColors');
    % Draw cluster centers 
    curColumnVarName = genvarname(sprintf('dimData_%d',curColumn));
    curLineVarName = genvarname(sprintf('dimData_%d',curLine));
    eval(['curScatter=scatter(curAxes,' curColumnVarName ',' ... 
      curLineVarName ',100,colorMatrix,''^'',''filled'');']);
    set(curScatter,'XDataSource',curColumnVarName,'YDataSource',...
      curLineVarName,'CDataSource','clusterColors');
    if curLine==nDim
      xlabel(curAxes,dimLabels{curColumn});
      set(curAxes,'XTick',xlim(curAxes));
    end
    if curColumn==1
      ylabel(curAxes,dimLabels{curLine});
      set(curAxes,'YTick',ylim(curAxes));
    end 
  end
end

refreshdata(figH,'caller');

% Preallocate gif frame. From Amro's tutorial here:
% https://stackoverflow.com/a/11054155/1162884
f = getframe(figH);
[f,map] = rgb2ind(f.cdata, 256, 'nodither');
mov = repmat(f, [1 1 1 nIters+4]);

% Add one frame with the start conditions:
curFrame = 1;
f = getframe(figH);
mov(:,:,1,curFrame) = rgb2ind(f.cdata, map, 'nodither');

for curIter = 1:nIters
  curFrame = curFrame+1;
  % Change label text
  set(annH,'String',sprintf('Iteration %d',curIter));
  % Update cluster point colors
  scatterColors = colorMatrix(label{curIter+1},:);
  % Update cluster centers:
  for curDim=1:nDim
    curDimVarName = genvarname(sprintf('dimData_%d',curDim));
    eval([curDimVarName,'= m{curIter+1}(curDim,:);']);
  end
  % Update cluster colors:
  nClusterIter = size(m{curIter+1},2);
  clusterColors = colorMatrix(1:nClusterIter,:);
  % Update graphics:
  refreshdata(figH,'caller');
  % Update the color legend:
  if nClusterIter~=size(m{curIter},2) % If the number of clusters
    % of the current iteration differs from the previous iteration (or
    % the start conditions, at the first iteration), redraw the legend:
    pcolor(pColorAxes,[1:nClusterIter+1;1:nClusterIter+1])
    colormap(pColorAxes,clusterColors);
    % Tick for cluster k at the center of face k (x = k+.5):
    set(pColorAxes,'XTick',1.5:1:nClusterIter+.5);
    set(pColorAxes,'YTick',[]);
    % Label each tick with its cluster number:
    set(pColorAxes,'XTickLabel',1:nClusterIter);
    xlabel(pColorAxes,'Clusters Colors','FontSize',plotOpts.FontSize);
  end
  f = getframe(figH);
  mov(:,:,1,curFrame) = rgb2ind(f.cdata, map, 'nodither');
end

set(annH,'String','Convergence Conditions');

for curFrame = nIters+2:nIters+4
  % Add three frames without movement at convergence conditions
  % (frames 1 to nIters+1 already hold the start and iteration frames)
  f = getframe(figH);
  mov(:,:,1,curFrame) = rgb2ind(f.cdata, map, 'nodither');
end

imwrite(mov, map, gifName, 'DelayTime',.5, 'LoopCount',inf)

varargout = cell(1,nargout);

if nargout > 0
  varargout{1} = label;
  if nargout > 1
    varargout{2} = m;
    if nargout > 2
      varargout{3} = figH;
      if nargout > 3
        varargout{4} = handles;
      end
    end
  end
end

end


function [leftPos,botPos,subplotWidth,subplotHeight] = ...
  setCustomPlotArea(handles,plotOpts,horSpace,vertSpace)
%
% -> handles: axes handles
%
% -> plotOpts: struct holding the following fields:
%
%   o leftBase: the percentage distance from the left
%
%   o rightBase: the percentage distance from the right
%
%   o bottomBase: the percentage distance from the bottom
%
%   o topBase: the percentage distance from the top
%
%   o widthUsableArea: Total width occupied by axes
%
%   o heigthUsableArea: Total heigth occupied by axes
%
% -> horSpace: the axes units size (integers only) that current axes
% should occupy in the horizontal (considering that other occupied
% axes handles are empty)
%
% -> vertSpace: the axes units size (integers only) that current axes
% should occupy in the vertical (considering that other occupied
% axes handles are empty)
%

nHorSubPlot =  size(handles,1);
nVertSubPlot = size(handles,2);

if nargin < 4
  vertSpace(nHorSubPlot,nVertSubPlot) = 0;
  vertSpace = vertSpace+1;
  if nargin < 3
    horSpace(nHorSubPlot,nVertSubPlot) = 0;
    horSpace = horSpace+1;
  end
end

subplotWidth = plotOpts.widthUsableArea/nHorSubPlot;
subplotHeight = plotOpts.heigthUsableArea/nVertSubPlot;

totalWidth = (1-plotOpts.rightBase) - plotOpts.leftBase;
totalHeight = (1-plotOpts.topBase) - plotOpts.bottomBase;

gapHeigthSpace = (totalHeight - ...
  plotOpts.heigthUsableArea)/(nVertSubPlot);
gapWidthSpace = (totalWidth - ...
  plotOpts.widthUsableArea)/(nHorSubPlot);

botPos(nVertSubPlot) = plotOpts.bottomBase + gapHeigthSpace/2;
leftPos(1) = plotOpts.leftBase + gapWidthSpace/2;

botPos(nVertSubPlot-1:-1:1) = botPos(nVertSubPlot) + (subplotHeight +...
  gapHeigthSpace)*(1:nVertSubPlot-1);
leftPos(2:nHorSubPlot) = leftPos(1) + (subplotWidth +...
  gapWidthSpace)*(1:nHorSubPlot-1);

for curLine=1:nHorSubPlot
  for curColumn=1:nVertSubPlot
    if handles(curLine,curColumn)
      set(handles(curLine,curColumn),'Position',[leftPos(curColumn)...
        botPos(curLine) horSpace(curLine,curColumn)*subplotWidth ...
        vertSpace(curLine,curColumn)*subplotHeight]);
    end
  end
end

end


function [handles,horSpace,vertSpace] = ...
  createAxesGrid(nLines,nColumns,plotOpts,dimLabels)

handles = zeros(nLines,nColumns);

% Those hold the axes size units:
horSpace(nLines,nColumns) = 0;
vertSpace(nLines,nColumns) = 0;

for curColumn=1:nColumns
  for curLine=curColumn+1:nLines
    handles(curLine,curColumn) = subplot(nLines,...
      nColumns,curColumn+(curLine-1)*nColumns);
    horSpace(curLine,curColumn) = 1;
    vertSpace(curLine,curColumn) = 1;
    curAxes = handles(curLine,curColumn);
    % Scatter colors are given as RGB triplets, so these axes need no
    % colormap (colorMatrix is not defined inside this function).
    set(curAxes,'NextPlot','add',...
      'FontSize',plotOpts.FontSize,'box','on'); 
    if curLine==nLines
      xlabel(curAxes,dimLabels{curColumn});
    else
      set(curAxes,'XTick',[]);
    end
    if curColumn==1
      ylabel(curAxes,dimLabels{curLine});
    else
      set(curAxes,'YTick',[]);
    end
  end
end
end

Example

Here is an example with 5 dimensions, using the code:

center1 = [1; 0; 0; 0; 0];
center2 = [0; 1; 0; 0; 0];
center3 = [0; 0; 1; 0; 0];
center4 = [0; 0; 0; 1; 0];
center5 = [0; 0; 0; 0; 1];
center6 = [0; 0; 0; 0; 1.5];
center7 = [0; 0; 0; 1.5; 1];
data = [...
        bsxfun(@plus,center1,.5*rand(5,20)) ...
        bsxfun(@plus,center2,.5*rand(5,20)) ...
        bsxfun(@plus,center3,.5*rand(5,20)) ...
        bsxfun(@plus,center4,.5*rand(5,20)) ...
        bsxfun(@plus,center5,.5*rand(5,20)) ...
        bsxfun(@plus,center6,.2*rand(5,20)) ...
        bsxfun(@plus,center7,.2*rand(5,20)) ...
       ];
[label,m,figH,handles]=kmeans_test(data,20);
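The same kind of test set can be built from a matrix of centers instead of seven separate variables. This is a sketch that reuses the seven centers and noise widths listed above; the loop structure and the names `centers`, `spread` and `nPerCluster` are my own.

```matlab
% Build the 5-D test set from a matrix of centers (columns = center1..center7).
centers = [eye(5), [0 0 0 0 1.5]', [0 0 0 1.5 1]'];  % 5 x 7
spread  = [.5 .5 .5 .5 .5 .2 .2];                     % uniform noise width per cluster
nPerCluster = 20;
blocks = cell(1, size(centers,2));
for j = 1:size(centers,2)
    % Each cluster: its center plus uniform noise in [0, spread(j)) per coordinate.
    blocks{j} = bsxfun(@plus, centers(:,j), spread(j)*rand(5, nPerCluster));
end
data = [blocks{:}];                                   % 5 x 140
```

Changing `centers` and `spread` is then enough to try other layouts with `kmeans_test`.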

Answered 2013-09-16T08:54:22.397