9

我有一个关于以函数式风格编写递归算法的问题。我将在这里使用 Scala 作为示例,但这个问题适用于任何函数式语言。

我正在对n叉树进行深度优先枚举,其中每个节点都有一个标签和可变数量的子节点。这是一个打印叶节点标签的简单实现。

case class Node[T](label:T, ns:Node[T]*)
def dfs[T](r:Node[T]):Seq[T] = {
    if (r.ns.isEmpty) Seq(r.label) else for (n<-r.ns;c<-dfs(n)) yield c
}
val r = Node('a, Node('b, Node('d), Node('e, Node('f))), Node('c))
dfs(r) // returns Seq[Symbol] = ArrayBuffer('d, 'f, 'c)

现在说,有时我希望能够通过抛出异常来放弃解析超大树。这在功能语言中可能吗?具体来说,这可能不使用可变状态吗?这似乎取决于您所说的“超大”是什么意思。这是该算法的纯函数版本,当它尝试处理深度为 3 或更大的树时会引发异常。

def dfs[T](r:Node[T], d:Int = 0):Seq[T] = {
    require(d < 3)
    if (r.ns.isEmpty) Seq(r.label) else for (n<-r.ns;c<-dfs(n, d+1)) yield c
}

但是,如果一棵树因为太宽而不是太深而变得过大怎么办?具体来说,如果我想在第n次递归调用函数时抛出异常,dfs()而不管递归有多深,该怎么办?我能看到如何做到这一点的唯一方法是拥有一个可变计数器,每次调用都会递增。如果没有可变变量,我看不到如何做到这一点。

我是函数式编程的新手,并且一直在假设你可以用可变状态做的任何事情都可以在没有的情况下完成,但我在这里看不到答案。我唯一能想到的就是编写一个版本,dfs()以深度优先的顺序返回树中所有节点的视图。

dfs[T](r:Node[T]):TraversableView[T, Traversable[_]] = ...

然后我可以通过说来施加我的限制dfs(r).take(n),但我不知道如何编写这个函数。在 Python 中,我只是在yield访问节点时创建一个生成器,但我不知道如何在 Scala 中实现相同的效果。(Scala 等效于 Python 样式的yield语句似乎是作为参数传入的访问者函数,但我不知道如何编写其中的一个来生成序列视图。)

编辑接近答案。

Stream这是一个以深度优先顺序返回节点的函数。

def dfs[T](r: Node[T]): Stream[Node[T]] = {
    (r #:: Stream.empty /: r.ns)(_ ++ dfs(_))
}

差不多就是这样。唯一的问题是Streammemoizes所有结果,这很浪费内存。我想要一个可遍历的视图。以下是想法,但没有编译。

def dfs[T](r: Node[T]): TraversableView[Node[T], Traversable[Node[T]]] = {
    (Traversable(r).view /: r.ns)(_ ++ dfs(_))
}

它为操作员提供了“found TraversableView[Node[T], Traversable[Node[T]]], requiredTraversableView[Node[T], Traversable[_]]错误++。如果我将返回类型更改为TraversableView[Node[T], Traversable[_]],我会在“found”和“required”子句切换时遇到同样的问题。所以有一些魔法类型的变体咒语我没有点燃然而,但这已经很接近了。

4

3 回答 3

7

可以做到:您只需要编写一些代码以您想要的方式实际迭代子项(而不是依赖于for)。

更明确地说,您必须编写代码来遍历子列表并检查“深度”是否超过了您的阈值。这是一些Haskell代码(我真的很抱歉,我不流利地使用 Scala,但这可能很容易音译):

http://ideone.com/O5gvhM

在这段代码中,我基本上将for循环替换为显式递归版本。如果访问节点的数量已经太深(即limit不是正数),这允许我停止递归。当我递归检查下一个孩子时,我减去前一个孩子访问的节点数,dfs并将其设置为下一个孩子的限制。

函数式语言很有趣,但它们是命令式编程的巨大飞跃。它确实让您关注状态的概念,因为当您使用函数时,所有这些都在参数中非常明确。

编辑:再解释一下。

我最终从“仅打印叶节点”(这是 OP 的原始算法)转换为“打印所有节点”。这使我能够通过结果列表的长度访问子调用访问的节点数。如果要坚持使用叶节点,则必须携带已经访问过的节点数:

http://ideone.com/cIQrna

再次编辑为了澄清这个答案,我将所有的 Haskell 代码放在 ideone 上,并且我已经将我的 Haskell 代码音译为 Scala,所以这可以留在这里作为问题的明确答案:

case class Node[T](label:T, children:Seq[Node[T]])

case class TraversalResult[T](num_visited:Int, labels:Seq[T])

def dfs[T](node:Node[T], limit:Int):TraversalResult[T] =
    limit match {
        case 0     => TraversalResult(0, Nil)
        case limit => 
            node.children match {
                case Nil => TraversalResult(1, List(node.label))
                case children => {
                    val result = traverse(node.children, limit - 1)
                    TraversalResult(result.num_visited + 1, result.labels)
                }
            }
    }

def traverse[T](children:Seq[Node[T]], limit:Int):TraversalResult[T] =
    limit match {
        case 0     => TraversalResult(0, Nil)
        case limit =>
            children match {
                case Nil => TraversalResult(0, Nil)
                case first :: rest => {
                    val trav_first = dfs(first, limit)
                    val trav_rest = 
                        traverse(rest, limit - trav_first.num_visited)
                    TraversalResult(
                        trav_first.num_visited + trav_rest.num_visited,
                        trav_first.labels ++ trav_rest.labels
                    )
                }
            }
    }

val n = Node(0, List(
    Node(1, List(Node(2, Nil), Node(3, Nil))),
    Node(4, List(Node(5, List(Node(6, Nil))))),
    Node(7, Nil)
))
for (i <- 1 to 8)
    println(dfs(n, i))

输出:

TraversalResult(1,List())
TraversalResult(2,List())
TraversalResult(3,List(2))
TraversalResult(4,List(2, 3))
TraversalResult(5,List(2, 3))
TraversalResult(6,List(2, 3))
TraversalResult(7,List(2, 3, 6))
TraversalResult(8,List(2, 3, 6, 7))

PS这是我第一次尝试Scala,所以上面可能包含一些可怕的非惯用代码。对不起。

于 2012-11-21T00:40:53.030 回答
4

您可以通过传递索引或获取尾部将广度转换为深度:

def suml(xs: List[Int], total: Int = 0) = xs match {
  case Nil => total
  case x :: rest => suml(rest, total+x)
}

def suma(xs: Array[Int], from: Int = 0, total: Int = 0) = {
  if (from >= xs.length) total
  else suma(xs, from+1, total + xs(from))
}

在后一种情况下,如果你愿意,你已经有一些东西可以限制你的广度;在前者中,只需添加一个width或类似的东西。

于 2012-11-21T00:16:41.610 回答
2

下面实现了对树中节点的惰性深度优先搜索。

import collection.TraversableView
case class Node[T](label: T, ns: Node[T]*)
def dfs[T](r: Node[T]): TraversableView[Node[T], Traversable[Node[T]]] =
  (Traversable[Node[T]](r).view /: r.ns) {
    (a, b) => (a ++ dfs(b)).asInstanceOf[TraversableView[Node[T], Traversable[Node[T]]]]
  }

这将按深度优先顺序打印所有节点的标签。

val r = Node('a, Node('b, Node('d), Node('e, Node('f))), Node('c))
dfs(r).map(_.label).force
// returns Traversable[Symbol] = List('a, 'b, 'd, 'e, 'f, 'c)

这做同样的事情,在访问了 3 个节点后退出。

dfs(r).take(3).map(_.label).force
// returns Traversable[Symbol] = List('a, 'b, 'd)

如果您只想要叶节点,您可以使用filter,依此类推。

请注意,dfs函数的 fold 子句需要显式asInstanceOf强制转换。请参阅“对 Traversable 视图执行 foldLeft 时 Scala 中的类型变化错误”,以讨论需要这样做的 Scala 类型问题。

于 2012-11-21T19:36:23.347 回答