2

我在这里寻找一个基本的伪代码大纲。

我的目标是从头开始编写分类树(我正在学习机器学习并希望获得直觉)。但我的训练数据非常庞大:40000 个示例和 1000 个特征。鉴于所需拆分数量的上限是 2 40000,我不知道如何跟踪所有这些分区数据集。

假设我从完整的数据集开始并进行一次拆分。然后我可以将落在拆分一侧的 20000 个示例保存到数据集中,然后重新运行拆分算法以找到该数据集的贪婪拆分。然后说我一直这样做,沿着树的最左边的树枝分裂了几十次。

当我对所有最左边的分割感到满意时,然后呢?如何存储多达 2 40000个单独的子集?以及在对测试示例进行分类时,如何跟踪我进行的所有拆分?这是对我来说没有意义的代码组织。

4

2 回答 2

2

感谢@natan 的详细回答。

但是,如果我理解正确,您关心的主要问题是如何在每个训练样本通过决策树传播时有效地跟踪它。

这可以很容易地完成。

您所需要的只是一个大小向量,N=40000每个训练样本都有一个条目。该向量将告诉您每个样本在树中的位置。让我们称这个向量assoc

如何使用这个向量?

在我看来,最优雅的方法是创建assoc类型uint32并使用位来编码每个训练样本在树中的传播。

中的每个位assoc(k)代表树的某个深度,如果该位设置为(1),则表示样本k向右,否则表示样本k向左。

如果您决定采用此策略,您会发现以下 Matlab 命令很有用bitgetbitset以及bitshift其他一些位函数。

让我们考虑以下树

       root
      /    \
     a      b
           / \
          c   d

因此,对于所有到节点a的示例,它们的assoc值是00b- 因为它们在根处离开(对应于最低有效位 (LSB) 处的零)。

所有去叶节点c的例子,它们的assoc值是01b- 它们在根处向右走(LSB=1),然后向左转(第 2 位 = 0)。

最后,所有到达叶子节点d的例子,它们的assoc值是11b——它们的分支太右了。

现在,您如何找到通过节点b的所有示例?

这简单!

>> selNodeB = bitand( assoc, 1 );

其 LSB 所在的所有节点1在根处右转并通过节点b

于 2013-01-10T08:05:23.857 回答
1

如果您认为有一种方法可以存储 2^40000 位,那么您还没有意识到这个数字有多大,并且您错了大约 10000 个数量级。查看 Matlab 的classregtree文档。

我从@Amro 的详细答案中复制了(在此处找到):

” 下面是分类树模型的几个常用参数:

  • x:数据矩阵,行是实例,列是预测属性
  • y:列向量,每个实例的类标签
  • categorical : 指定哪些属性是离散类型(相对于连续)
  • method : 是生成分类树还是回归树(取决于类类型)
  • names : 为属性命名
  • prune:启用/禁用减少错误修剪
  • minparent/minleaf:如果要进一步拆分,允许指定节点中的最小实例数
  • nvartosample:用于随机树(考虑每个节点上随机选择的 K 个属性)
  • weights:指定加权实例
  • cost : 指定成本矩阵(各种错误的惩罚)
  • splitcriterion:用于在每次拆分时选择最佳属性的标准。我只熟悉基尼指数,它是信息增益标准的一种变体。
  • priorityprob:明确指定先验类概率,而不是根据训练数据计算

一个完整的例子来说明这个过程:

%# load data
load carsmall

%# construct predicting attributes and target class
vars = {'MPG' 'Cylinders' 'Horsepower' 'Model_Year'};
x = [MPG Cylinders Horsepower Model_Year];
y = strcat(Origin,{});

%# train classification decision tree
t = classregtree(x, y, 'method','classification', 'names',vars, ...
                'categorical', [2 4], 'prune','off');
view(t)

%# test
yPredicted = eval(t, x);
cm = confusionmat(y,yPredicted);           %# confusion matrix
N = sum(cm(:));
err = ( N-sum(diag(cm)) ) / N;             %# testing error

%# prune tree to avoid overfitting
tt = prune(t, 'level',2);
view(tt)

%# predict a new unseen instance
inst = [33 4 78 NaN];
prediction = eval(tt, inst)

树

于 2013-01-10T07:38:20.067 回答