我有点不好意思承认这一点,但我似乎被一个简单的编程问题难住了。我正在构建一个决策树实现,并且一直在使用递归来获取标记样本的列表,递归地将列表分成两半,然后将其变成一棵树。
不幸的是,对于深树,我遇到了堆栈溢出错误(哈!),所以我的第一个想法是使用延续将其变成尾递归。不幸的是,Scala 不支持这种 TCO,因此唯一的解决方案是使用蹦床。蹦床似乎有点低效,我希望有一些简单的基于堆栈的命令式解决方案来解决这个问题,但我很难找到它。
递归版本看起来有点像(简化):
private def trainTree(samples: Seq[Sample], usedFeatures: Set[Int]): DTree = {
if (shouldStop(samples)) {
DTLeaf(makeProportions(samples))
} else {
val featureIdx = getSplittingFeature(samples, usedFeatures)
val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _))
DTBranch(
trainTree(statsWithFeature, usedFeatures + featureIdx),
trainTree(statsWithoutFeature, usedFeatures + featureIdx),
featureIdx)
}
}
所以基本上我会根据数据的某些特征将列表递归地细分为两部分,并传递一个已使用特征的列表,所以我不会重复——这一切都在“getSplittingFeature”函数中处理,所以我们可以忽略它。代码真的很简单!尽管如此,我仍然无法找出一个基于堆栈的解决方案,它不仅使用闭包而且有效地成为蹦床。我知道我们至少必须在堆栈中保留小的“框架”参数,但我想避免闭包调用。
我知道我应该在递归解决方案中显式地写出调用堆栈和程序计数器为我处理的内容,但是如果没有继续,我很难做到这一点。在这一点上,它甚至与效率无关,我只是好奇。所以请不要提醒我,过早的优化是万恶之源,基于蹦床的解决方案可能会工作得很好。我知道它可能会——这基本上是一个谜。
谁能告诉我这种事情的规范的基于 while-loop-and-stack 的解决方案是什么?
更新:基于 Thipor Kong 的优秀解决方案,我编写了一个基于 while-loops/stacks/hashtable 的算法实现,它应该是递归版本的直接翻译。这正是我一直在寻找的:
最后更新:我使用了顺序整数索引,以及将所有内容放回数组而不是映射以提高性能,添加了 maxDepth 支持,最后有了一个与递归版本具有相同性能的解决方案(不确定内存使用情况,但我会少猜):
private def trainTreeNoMaxDepth(startingSamples: Seq[Sample], startingMaxDepth: Int): DTree = {
// Use arraybuffer as dense mutable int-indexed map - no IndexOutOfBoundsException, just expand to fit
type DenseIntMap[T] = ArrayBuffer[T]
def updateIntMap[@specialized T](ab: DenseIntMap[T], idx: Int, item: T, dfault: T = null.asInstanceOf[T]) = {
if (ab.length <= idx) {ab.insertAll(ab.length, Iterable.fill(idx - ab.length + 1)(dfault)) }
ab.update(idx, item)
}
var currentChildId = 0 // get childIdx or create one if it's not there already
def child(childMap: DenseIntMap[Int], heapIdx: Int) =
if (childMap.length > heapIdx && childMap(heapIdx) != -1) childMap(heapIdx)
else {currentChildId += 1; updateIntMap(childMap, heapIdx, currentChildId, -1); currentChildId }
// go down
val leftChildren, rightChildren = new DenseIntMap[Int]() // heapIdx -> childHeapIdx
val todo = Stack((startingSamples, Set.empty[Int], startingMaxDepth, 0)) // samples, usedFeatures, maxDepth, heapIdx
val branches = new Stack[(Int, Int)]() // heapIdx, featureIdx
val nodes = new DenseIntMap[DTree]() // heapIdx -> node
while (!todo.isEmpty) {
val (samples, usedFeatures, maxDepth, heapIdx) = todo.pop()
if (shouldStop(samples) || maxDepth == 0) {
updateIntMap(nodes, heapIdx, DTLeaf(makeProportions(samples)))
} else {
val featureIdx = getSplittingFeature(samples, usedFeatures)
val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _))
todo.push((statsWithFeature, usedFeatures + featureIdx, maxDepth - 1, child(leftChildren, heapIdx)))
todo.push((statsWithoutFeature, usedFeatures + featureIdx, maxDepth - 1, child(rightChildren, heapIdx)))
branches.push((heapIdx, featureIdx))
}
}
// go up
while (!branches.isEmpty) {
val (heapIdx, featureIdx) = branches.pop()
updateIntMap(nodes, heapIdx, DTBranch(nodes(child(leftChildren, heapIdx)), nodes(child(rightChildren, heapIdx)), featureIdx))
}
nodes(0)
}