0

我正在尝试计算一组给定观察值的均值和协方差矩阵。点列表是一个 3-d 数组,第一维表示类别编号,第二维表示观察编号,第三维表示坐标编号。虽然我已经能够计算出平均值,但协方差似乎存在一些问题(现在,我得到一个零矩阵)。如果有人能告诉我如何纠正它,我将不胜感激。

function [ meanEst, covEst, priorProbEst, classMem ] = estimateParams( trainingSet, classList )
%estimateParams estimate all parameters for each class

numRows = size(trainingSet, 1);
numClasses = max(classList.');
%pointList = zeros(numClasses, numRows, 2);
classMem = zeros(numClasses, 1);

for rowCtr = 1:numRows
    curClass = classList(rowCtr, 1);
    classMem(curClass) = classMem(curClass) + 1;
    pointList(curClass, classMem(curClass), 1) = trainingSet(rowCtr, 1);
    pointList(curClass, classMem(curClass), 2) = trainingSet(rowCtr, 2);
end

meanEst      = zeros(numClasses, 2);
covEst       = zeros(numClasses, 2, 2);
priorProbEst = zeros(numClasses, 1);
tot          = zeros(numClasses, 2);

for classCtr = 1:numClasses
    for pointCtr = 1:classMem(classCtr)
        tot(classCtr, 1) = tot(classCtr, 1) + pointList(classCtr, pointCtr, 1);
        tot(classCtr, 2) = tot(classCtr, 2) + pointList(classCtr, pointCtr, 2);
    end
    meanEst(classCtr, 1) = tot(classCtr, 1) / classMem(classCtr);
    meanEst(classCtr, 2) = tot(classCtr, 2) / classMem(classCtr);

    covEst(classCtr) = cov(pointList(classCtr));
    priorProbEst(classCtr) = classMem(classCtr) / numRows;
end
end

感谢您花时间在这上面!

4

1 回答 1

1

我认为您通过引入 3dpointList矩阵使事情复杂化。如果感觉还可以,您可以这样做,但在某个地方存在协方差估计错误。

没有理由将您的数据保存在这样的结构中,因为您有每个观察的类 ID(即,您的每一行trainingSet都有来自相应行的标签classList)。因此,您始终可以使用逻辑索引trainingSet来检索用于估计mean和的数据cov.作为一项规则,N x M = observation x variables任何估计/分类任务的数据矩阵都是一种总是有帮助的约定,并且与许多 MATLAB 函数一致。

例如,下面我创建了一个随机训练集(NxM 矩阵)和标签索引(Nx1 列表中的 K=4 个类)并估计每个类的均值和协方差,分别将结果分配到 aKx22x2xK矩阵中。

nPoints = 200; % training set points
nClass = 4; % number of unique classes

% random training set of size nPoints x 2 (coordinates)
classList = randi(nClass, nPoints, 1);
trainingSet = randn(nPoints, 2);

meanEst = zeros(nClass, 2);
covEst = zeros(2, 2, nClass);
for classID = 1:nClass
    meanEst(classID,:) = mean(trainingSet(classList==classID,:));
    covEst(:,:,classID) = cov(trainingSet(classList==classID,:));
end

作为证明,运行您的代码将产生与mean上述示例相同的结果。

于 2012-09-02T23:22:38.977 回答