下面是一个完整的示例,说明如何使用 BNT 工具箱构建朴素贝叶斯网络。我使用的是汽车数据集(carsmall)的一个子集,其中既包含离散属性,也包含连续属性。
为方便起见,我使用了几个依赖统计工具箱(Statistics Toolbox)的函数。
我们首先准备数据集:
%# Read the carsmall sample data into a struct (one field per variable).
D = load('carsmall');

%# Drop the variables that are not used as network nodes.
unusedVars = {'Mfg','Horsepower','Displacement','Model'};
D = rmfield(D, unusedVars);

%# Keep only the rows whose origin is USA or Japan; the same row mask is
%# applied to every remaining field so the fields stay aligned.
idx = ismember(D.Origin, {'USA' 'Japan'});
D = structfun(@(x)x(idx,:), D, 'UniformOutput',false);
numInst = sum(idx);

%# Impute missing MPG values with the mean of the non-missing ones.
mpgMissing = isnan(D.MPG);
D.MPG(mpgMissing) = nanmean(D.MPG);

%# Recode the discrete attributes as group indices 1..K; the third output
%# of grp2idx (the group levels) is kept so indices can be mapped back.
[D.Cylinders,~,gnCylinders] = grp2idx( D.Cylinders );
[D.Model_Year,~,gnModel_Year] = grp2idx( D.Model_Year );
[D.Origin,~,gnOrigin] = grp2idx( cellstr(D.Origin) );
接下来我们构建概率图模型(graphical model):
%# Node bookkeeping: every field of D is one node, numbered in field order.
nodeNames = fieldnames(D);
numNodes = numel(nodeNames);

%# node.<Name> is the numeric index of the node called <Name>.
node = cell2struct(num2cell((1:numNodes)'), nodeNames, 1);

%# Index lists: discrete nodes, continuous nodes, and depNodes, which
%# lists every node named here except Origin.
dNodes = [node.Origin, node.Cylinders, node.Model_Year];
cNodes = [node.MPG, node.Weight, node.Acceleration];
depNodes = [node.MPG, node.Cylinders, node.Weight, ...
            node.Acceleration, node.Model_Year];

%# Distinct values of each discrete node; entries of continuous nodes
%# are left empty.
vals = cell(1,numNodes);
for k = dNodes
    vals{k} = unique(D.(nodeNames{k}));
end

%# Node sizes: number of distinct values for discrete nodes, 1 otherwise.
nodeSize = ones(1,numNodes);
for k = dNodes
    nodeSize(k) = numel(vals{k});
end
%# Adjacency matrix of the DAG: the only edges run from Origin to each
%# node listed in depNodes.
dag = false(numNodes);
dag(node.Origin, depNodes) = true;

%# Assemble the naive Bayes network; depNodes are declared as observed.
bnet = mk_bnet(dag, nodeSize, ...
    'discrete', dNodes, ...
    'names', nodeNames, ...
    'observed', depNodes);

%# Discrete nodes get a tabular CPD with a Dirichlet prior.
%# (node.(nodeNames{n}) equals n by construction, so the node index is
%# passed directly.)
for n = dNodes
    bnet.CPD{n} = tabular_CPD(bnet, n, 'prior_type','dirichlet');
end

%# Continuous nodes get a Gaussian CPD.
for n = cNodes
    bnet.CPD{n} = gaussian_CPD(bnet, n);
end
%# Draw the DAG. The third output holds graphics handles; judging by how
%# they are used below, column 1 is the label text and column 2 the node
%# shapes (confirm against draw_graph).
[~,~,h] = draw_graph(bnet.dag, nodeNames);
hTxt = h(:,1);
hNodes = h(:,2);

%# Class node: green fill, bold label.
classIdx = node.Origin;
set(hNodes(classIdx), 'FaceColor','g')
set(hTxt(classIdx), 'FontWeight','bold', 'Interpreter','none')

%# Dependent nodes: yellow fill, black label. The 'none' interpreter keeps
%# underscores in names such as Model_Year from being typeset as subscripts.
set(hNodes(depNodes), 'FaceColor','y')
set(hTxt(depNodes), 'Color','k', 'Interpreter','none')
现在我们将数据划分为训练集和测试集:
%# Build the samples as a cell array with one row per node (rows follow the
%# field order of D, i.e. the same order as nodeNames) and one column per
%# instance.
data = num2cell(cell2mat(struct2cell(D)')');

%# Split train/test: hold out 1/3 for testing, keep 2/3 for training.
%# The partition is computed from the class labels in D.Origin.
cv = cvpartition(D.Origin, 'HoldOut',1/3);
trainData = data(:,cv.training);
testData = data(:,cv.test);

%# Hide the class from the test evidence. The row is addressed through
%# node.Origin rather than a hard-coded 1 so that the correct attribute is
%# blanked regardless of where Origin sits in the field order of D.
testData(node.Origin,:) = {[]};
最后我们从训练集中学习参数,并预测测试数据的类别:
%# Training: fit the CPD parameters from the training samples.
bnet = learn_params(bnet, trainData);

%# Testing: for every test sample, enter its evidence and read off the
%# marginal of the class node.
engine = jtree_inf_engine(bnet);    %# junction-tree inference engine
prob = zeros(nodeSize(node.Origin), sum(cv.test));
numTest = size(testData,2);
for t = 1:numTest
    evidence = testData(:,t);
    [engine,loglik] = enter_evidence(engine, evidence);
    marg = marginal_nodes(engine, node.Origin);
    prob(:,t) = marg.T;             %# one column of class probabilities per sample
end

%# Predicted class = row index of the largest probability in each column.
[~,pred] = max(prob);

%# True class indices of the held-out samples, as a row vector like pred.
actual = D.Origin(cv.test).';
%# Confusion matrix.
%# Convert the class indices to 1-of-K indicator matrices (one row per test
%# sample, one column per class). The matrix size is passed explicitly:
%# without it, sparse() sizes the result from the largest index present, so
%# a class that is never predicted (or never occurs in the test labels)
%# would yield matrices with different column counts.
numClasses = nodeSize(node.Origin);
predInd = full(sparse(1:numel(pred), pred, 1, numel(pred), numClasses));
actualInd = full(sparse(1:numel(actual), actual, 1, numel(actual), numClasses));
conffig(predInd, actualInd);        %# confmat
%# ROC plot and AUC.
%# The score is the winning class probability of each test sample and the
%# label is whether that prediction was correct. plotROC is an external
%# helper; its third output is used as the AUC below.
figure
confidence = max(prob);
isCorrect = (pred == actual);
%# No trailing semicolon: auc is intentionally echoed to the console.
[~,~,auc] = plotROC(confidence, isCorrect, 'b')
title(sprintf('Area Under the Curve = %g',auc))

%# Thicken every line object in the current axes.
hLines = findobj(gca, 'type','line');
set(hLines, 'LineWidth',2)
结果:
我们可以提取每个离散节点的 CPT(条件概率表),以及每个连续节点的均值和协方差(sigma):
%# Inspect the learned parameters.
%# Each CPD is converted with struct() so that its fields can be read
%# directly (presumably these are the CPD object's internal fields --
%# confirm against BNT).
%# Discrete nodes: pass the CPT field of each CPD to dispcpt.
cellfun(@(x)dispcpt(struct(x).CPT), bnet.CPD(dNodes), 'Uniform',false)
%# Continuous nodes: display the mean field of each CPD.
celldisp(cellfun(@(x)struct(x).mean, bnet.CPD(cNodes), 'Uniform',false))
%# Continuous nodes: display the cov field of each CPD.
celldisp(cellfun(@(x)struct(x).cov, bnet.CPD(cNodes), 'Uniform',false))