
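I wrote a simple test bench to measure the performance of three factorial implementations: loop-based, non-tail-recursive and tail-recursive.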

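To my surprise, the worst performers were the loops ("while" was supposed to be more efficient, so I provided both), costing almost twice as much as the tail-recursive alternative.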

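*ANSWER: after fixing the loop implementations to avoid the *= operator, the «loops» become the fastest; BigInt behaved differently from what I expected.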

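Another «voodoo» behavior I experienced was the StackOverflow exception: for the non-tail-recursive implementation it was not thrown systematically for the same input. I could circumvent the StackOverflow by calling the function progressively with larger and larger values... that felt crazy :) ANSWER: the JVM needs to converge during startup; after that the behavior is coherent and systematic.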

Here is the code:

import scala.annotation.tailrec // needed for the @tailrec annotations below

final object Factorial {
  type Out = BigInt

  def calculateByRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    n match {
      case _ if n == 1 => return 1
      case _ => return n * calculateByRecursion(n-1)
    }
  }

  def calculateByForLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var accumulator: Out = 1
    for (i <- 1 to n)
      accumulator = i * accumulator
    accumulator
  }

  def calculateByWhileLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var accumulator: Out = 1
    var i = 1
    while (i <= n) {
      accumulator = i * accumulator
      i += 1
    }
    accumulator
  }

  def calculateByTailRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    @tailrec def fac(n: Int, acc: Out): Out = n match {
      case _ if n == 1 => acc
      case _ => fac(n-1, n * acc)
    }

    fac(n, 1)
  }

  def calculateByTailRecursionUpward(n: Int): Out = {
    require(n>0, "n must be positive")

    @tailrec def fac(i: Int, acc: Out): Out = n match {
      case _ if i == n => n * acc
      case _ => fac(i+1, i * acc)
    }

    fac(1, 1)
  }

  def comparePerformance(n: Int): Unit = {
    def showOutput[A](msg: String, data: (Long, A), showOutput:Boolean = false) =
      showOutput match {
        case true => printf("%s returned %s in %d ms\n", msg, data._2.toString, data._1)
        case false => printf("%s in %d ms\n", msg, data._1)
    }
    def measure[A](f:()=>A): (Long, A) = {
      val start = System.currentTimeMillis
      val o = f()
      (System.currentTimeMillis - start, o)
    }
    showOutput ("By for loop", measure(()=>calculateByForLoop(n)))
    showOutput ("By while loop", measure(()=>calculateByWhileLoop(n)))
    showOutput ("By non-tail recursion", measure(()=>calculateByRecursion(n)))
    showOutput ("By tail recursion", measure(()=>calculateByTailRecursion(n)))
    showOutput ("By tail recursion upward", measure(()=>calculateByTailRecursionUpward(n)))
  }
}
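The measure helper above times a single call with System.currentTimeMillis, so early calls also pay for JIT compilation, which fits the «JVM needs to converge during startup» observation. A minimal sketch of a helper that runs the function a few times before timing it; the name timeAfterWarmup and the default counts are illustrative assumptions, and for serious numbers a harness such as ScalaMeter or JMH remains the better tool:

```scala
// Sketch: time a function only after a few untimed warm-up calls.
// `WarmTiming`, `timeAfterWarmup` and the default counts are hypothetical choices.
object WarmTiming {
  /** Calls `f` once plus `warmup` more times without timing, then returns the
    * average elapsed time in milliseconds over `runs` timed calls, with the last result. */
  def timeAfterWarmup[A](f: () => A, warmup: Int = 5, runs: Int = 5): (Double, A) = {
    require(runs > 0, "runs must be positive")
    var result: A = f() // first (cold) call, not timed
    var i = 0
    while (i < warmup) { result = f(); i += 1 } // give the JIT a chance to compile `f`
    val start = System.nanoTime() // nanoTime is intended for measuring elapsed time
    i = 0
    while (i < runs) { result = f(); i += 1 }
    val elapsedMs = (System.nanoTime() - start) / 1e6 / runs
    (elapsedMs, result)
  }
}
```

For example, WarmTiming.timeAfterWarmup(() => Factorial.calculateByWhileLoop(20000)) could stand in for one of the measure calls.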

Here is some output from the sbt console (before the «while» implementation):

scala> example.Factorial.comparePerformance(10000)
By loop in 3 ms
By non-tail recursion in >>>>> StackOverflow!!!!!… see later!!!
........

scala> example.Factorial.comparePerformance(1000)
By loop in 3 ms
By non-tail recursion in 1 ms
By tail recursion in 4 ms

scala> example.Factorial.comparePerformance(5000)
By loop in 105 ms
By non-tail recursion in 27 ms
By tail recursion in 34 ms

scala> example.Factorial.comparePerformance(10000)
By loop in 236 ms
By non-tail recursion in 106 ms     >>>> Now works!!!
By tail recursion in 127 ms

scala> example.Factorial.comparePerformance(20000)
By loop in 977 ms
By non-tail recursion in 495 ms
By tail recursion in 564 ms

scala> example.Factorial.comparePerformance(30000)
By loop in 2285 ms
By non-tail recursion in 1183 ms
By tail recursion in 1281 ms

Here is some output from the sbt console (after the «while» implementation):

scala> example.Factorial.comparePerformance(10000)
By for loop in 252 ms
By while loop in 246 ms
By non-tail recursion in 130 ms
By tail recursion in 136 ms

scala> example.Factorial.comparePerformance(20000)
By for loop in 984 ms
By while loop in 1091 ms
By non-tail recursion in 508 ms
By tail recursion in 560 ms

Next, some output from the sbt console (after the «upward» tail-recursion implementation): the world is sane again.

scala> example.Factorial.comparePerformance(10000)
By for loop in 259 ms
By while loop in 229 ms
By non-tail recursion in 114 ms
By tail recursion in 119 ms
By tail recursion upward in 105 ms

scala> example.Factorial.comparePerformance(20000)
By for loop in 1053 ms
By while loop in 957 ms
By non-tail recursion in 513 ms
By tail recursion in 565 ms
By tail recursion upward in 470 ms

Here is some output from the sbt console after fixing the BigInt multiplication in the «loops»: the world is completely normal.

scala> example.Factorial.comparePerformance(20000)
By for loop in 498 ms
By while loop in 502 ms
By non-tail recursion in 521 ms
By tail recursion in 611 ms
By tail recursion upward in 503 ms

BigInt overhead and my silly implementation masked the expected behavior.

PS: In the end I should retitle this post «A lesson learnt on BigInts».


3 Answers


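For loops aren't actually quite loops. They are for comprehensions over a range. If you actually want a loop, you need to use while. (Actually, I think the BigInt multiplication here is heavyweight enough that it doesn't matter. But you would notice if you were multiplying Ints.)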
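As a sketch of what this means: a for loop without yield over 1 to n is rewritten by the compiler into a foreach call on a Range with a closure, whereas while compiles to a plain loop. The object and method names below are illustrative, and BigInt(i) * acc is spelled out instead of relying on the implicit Int-to-BigInt conversion:

```scala
// Sketch: the same factorial written as a for loop, its desugared form, and a while loop.
object ForDesugaring {
  // As in the question: `for` without `yield` over a Range.
  def forLoop(n: Int): BigInt = {
    var acc: BigInt = 1
    for (i <- 1 to n) acc = BigInt(i) * acc
    acc
  }

  // Roughly what the compiler produces for `forLoop`: Range.foreach with a
  // closure that captures and updates `acc`.
  def desugared(n: Int): BigInt = {
    var acc: BigInt = 1
    (1 to n).foreach(i => acc = BigInt(i) * acc)
    acc
  }

  // A plain while loop: no Range object and no closure.
  def whileLoop(n: Int): BigInt = {
    var acc: BigInt = 1
    var i = 1
    while (i <= n) { acc = BigInt(i) * acc; i += 1 }
    acc
  }
}
```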

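Also, you are using BigInt. The bigger your BigInt, the slower your multiplication. Your non-tail recursion counts up while your tail recursion counts down, which means the latter has more big numbers to multiply.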
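One way to see the effect of the multiplication order is to add up the bit lengths of the intermediate accumulators. Treating that sum as a rough proxy for the total BigInt work is an assumption of this sketch, not a precise cost model, and the names are hypothetical:

```scala
// Sketch: compare intermediate accumulator sizes when multiplying 1..n upward vs downward.
object AccumulatorGrowth {
  // Sum of the accumulator's bit length after each multiplication.
  private def totalBits(factors: Seq[Int]): Long = {
    var acc = BigInt(1)
    var bits = 0L
    for (f <- factors) {
      acc = BigInt(f) * acc
      bits += acc.bitLength
    }
    bits
  }

  // Order used by the non-tail recursion: 2 * 1, then 3 * 2, ... up to n.
  def countingUp(n: Int): Long = totalBits(1 to n)

  // Order used by the tail recursion: n * 1, then (n - 1) * n, ... down to 1.
  def countingDown(n: Int): Long = totalBits(n to 1 by -1)
}
```

For n > 1, countingDown(n) comes out larger: after k steps the descending accumulator holds n!/(n-k)!, which is never smaller than the k! held by the ascending one.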

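If you fix both of these issues, you will find that sanity is restored: loops and tail recursion run at the same speed, while regular recursion and for are slower. (Regular recursion may not be slower if JVM optimizations make it equivalent.)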

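(Also, the stack overflow going away is probably because the JVM starts inlining, and may either make the call tail-recursive itself or unroll the loop far enough that you no longer overflow.)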

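Finally, you were getting poor results with for and while because you were multiplying by the small number on the right rather than on the left. It turns out that Java's BigInteger (which Scala's BigInt wraps) multiplies faster with the smaller number on the left.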
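In Scala, acc *= x on a BigInt is shorthand for acc = acc * x, so the large accumulator is the receiver and the small factor is the argument; writing the product the other way round swaps the operands that, for large values, end up in java.math.BigInteger.multiply. The sketch below (hypothetical names) only shows the two forms and that they compute the same value; which one is faster depends on the JDK's BigInteger implementation, so it is worth timing on your own JVM:

```scala
// Sketch: the two operand orders discussed above.
object OperandOrder {
  // `acc *= BigInt(i)` expands to `acc = acc * BigInt(i)`: big receiver, small argument.
  def smallOnRight(n: Int): BigInt = {
    var acc = BigInt(1)
    var i = 1
    while (i <= n) { acc *= BigInt(i); i += 1 }
    acc
  }

  // The fixed form from the question: small receiver, big argument.
  def smallOnLeft(n: Int): BigInt = {
    var acc = BigInt(1)
    var i = 1
    while (i <= n) { acc = BigInt(i) * acc; i += 1 }
    acc
  }
}
```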

answered 2013-02-28T15:07:13.800

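I know everyone has already answered this question, but I thought I might add one optimization: converting the pattern matching to simple if expressions can speed up the tail recursion.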

final object Factorial {
  type Out = BigInt

  def calculateByRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    n match {
      case _ if n == 1 => return 1
      case _ => return n * calculateByRecursion(n-1)
    }
  }

  def calculateByForLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var accumulator: Out = 1
    for (i <- 1 to n)
      accumulator = i * accumulator
    accumulator
  }

  def calculateByWhileLoop(n: Int): Out = {
    require(n>0, "n must be positive")

    var acc: Out = 1
    var i = 1
    while (i <= n) {
      acc = i * acc
      i += 1
    }
    acc
  }

  def calculateByTailRecursion(n: Int): Out = {
    require(n>0, "n must be positive")

    @annotation.tailrec
    def fac(n: Int, acc: Out): Out = if (n==1) acc else fac(n-1, n*acc)

    fac(n, 1)
  }

  def calculateByTailRecursionUpward(n: Int): Out = {
    require(n>0, "n must be positive")

    @annotation.tailrec
    def fac(i: Int, acc: Out): Out = if (i == n) n*acc else fac(i+1, i*acc)

    fac(1, 1)
  }

  def attempt(f: ()=>Unit): Boolean = {
    try {
        f()
        true
    } catch {
        case _: Throwable =>
            println(" <<<<< Failed...")
            false
    }
  }

  def comparePerformance(n: Int): Unit = {
    def showOutput[A](msg: String, data: (Long, A), showOutput:Boolean = true) =
      showOutput match {
        case true =>
            val res = data._2.toString
            val pref = res.substring(0,5)
            val midd = res.substring((res.length-5)/ 2, (res.length-5)/ 2 + 10)
            val suff = res.substring(res.length-5)
            printf("%s returned %s in %d ms\n", msg, s"$pref...$midd...$suff" , data._1)
        case false => 
            printf("%s in %d ms\n", msg, data._1)
    }
    def measure[A](f:()=>A): (Long, A) = {
      val start = System.currentTimeMillis
      val o = f()
      (System.currentTimeMillis - start, o)
    }
    attempt(() => showOutput ("By for loop", measure(()=>calculateByForLoop(n))))
    attempt(() => showOutput ("By while loop", measure(()=>calculateByWhileLoop(n))))
    attempt(() => showOutput ("By non-tail recursion", measure(()=>calculateByRecursion(n))))
    attempt(() => showOutput ("By tail recursion", measure(()=>calculateByTailRecursion(n))))
    attempt(() => showOutput ("By tail recursion upward", measure(()=>calculateByTailRecursionUpward(n))))
  }
}

My results:

scala> Factorial.comparePerformance(20000)
By for loop returned 18192...5708616582...00000 in 179 ms
By while loop returned 18192...5708616582...00000 in 159 ms
By non-tail recursion <<<<< Failed...
By tail recursion returned 18192...5708616582...00000 in 169 ms
By tail recursion upward returned 18192...5708616582...00000 in 174 ms
By for loop returned 18192...5708616582...00000 in 212 ms
By while loop returned 18192...5708616582...00000 in 156 ms
By non-tail recursion returned 18192...5708616582...00000 in 155 ms
By tail recursion returned 18192...5708616582...00000 in 166 ms
By tail recursion upward returned 18192...5708616582...00000 in 137 ms
scala> Factorial.comparePerformance(200000)
By for loop returned 14202...0169293868...00000 in 17467 ms
By while loop returned 14202...0169293868...00000 in 17303 ms
By non-tail recursion <<<<< Failed...
By tail recursion returned 14202...0169293868...00000 in 18477 ms
By tail recursion upward returned 14202...0169293868...00000 in 17188 ms
answered 2019-07-08T17:38:34.193

Scala static methods for factorial(n) (coded with Scala 2.12.x on Java 8):

object Factorial {

  /*
   * For large N, it throws a stack overflow
   */
  def recursive(n:BigInt): BigInt = {
    if(n < 0) {
      throw new ArithmeticException
    } else if(n <= 1) {
      1
    } else {
      n * recursive(n - 1)
    }
  }

  /*
   * A tail recursive method is compiled to avoid stack overflow
   */
  @scala.annotation.tailrec
  def recursiveTail(n:BigInt, acc:BigInt = 1): BigInt = {
    if(n < 0) {
      throw new ArithmeticException
    } else if(n <= 1) {
      acc
    } else {
      recursiveTail(n - 1, n * acc)
    }
  }

  /*
   * A while loop
   */
  def loop(n:BigInt): BigInt = {
    if(n < 0) {
      throw new ArithmeticException
    } else if(n <= 1) {
      1
    } else {
      var acc: BigInt = 1 // must be BigInt: an inferred Int silently overflows for n > 12
      var idx = 1
      while(idx <= n) {
        acc = idx * acc
        idx += 1
      }
      acc
    }
  }

}
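In loop, the accumulator's type is whatever var acc = ... infers: from an Int literal it is Int, and the product then wraps around silently from 13! on, something the isInstanceOf[BigInt] check for large inputs cannot catch because the method's return type is BigInt anyway. A small sketch contrasting the two declarations (the names are hypothetical):

```scala
// Sketch: the same while loop with an Int accumulator and with a BigInt accumulator.
object AccumulatorType {
  // `var acc = 1` is inferred as Int, so the product overflows past Int.MaxValue.
  def loopWithIntAcc(n: Int): BigInt = {
    var acc = 1
    var idx = 1
    while (idx <= n) { acc = idx * acc; idx += 1 }
    BigInt(acc) // widened only at the end, after any overflow has already happened
  }

  // Declaring the accumulator as BigInt keeps every multiplication exact.
  def loopWithBigIntAcc(n: Int): BigInt = {
    var acc: BigInt = 1
    var idx = 1
    while (idx <= n) { acc = BigInt(idx) * acc; idx += 1 }
    acc
  }
}
```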

Specs:

class FactorialSpecs extends SpecHelper {

  private val smallInt = 10
  private val largeInt = 10000

  describe("Factorial.recursive") {
    it("return 1 for 0") {
      assert(Factorial.recursive(0) == 1)
    }
    it("return 1 for 1") {
      assert(Factorial.recursive(1) == 1)
    }
    it("return 2 for 2") {
      assert(Factorial.recursive(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.recursive(smallInt) == 3628800)
    }
    it("throws StackOverflow for large inputs") {
      intercept[java.lang.StackOverflowError] {
        Factorial.recursive(Int.MaxValue)
      }
    }
  }

  describe("Factorial.recursiveTail") {
    it("return 1 for 0") {
      assert(Factorial.recursiveTail(0) == 1)
    }
    it("return 1 for 1") {
      assert(Factorial.recursiveTail(1) == 1)
    }
    it("return 2 for 2") {
      assert(Factorial.recursiveTail(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.recursiveTail(smallInt) == 3628800)
    }
    it("returns a result, for large inputs") {
      assert(Factorial.recursiveTail(largeInt).isInstanceOf[BigInt])
    }
  }

  describe("Factorial.loop") {
    it("return 1 for 0") {
      assert(Factorial.loop(0) == 1)
    }
    it("return 1 for 1") {
      assert(Factorial.loop(1) == 1)
    }
    it("return 2 for 2") {
      assert(Factorial.loop(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.loop(smallInt) == 3628800)
    }
    it("returns a result, for large inputs") {
      assert(Factorial.loop(largeInt).isInstanceOf[BigInt])
    }
  }
}

Benchmark:

import org.scalameter.api._

class BenchmarkFactorials extends Bench.OfflineReport {

  val gen: Gen[Int] = Gen.range("N")(1, 1000, 100) // scalastyle:ignore

  performance of "Factorial" in {
    measure method "loop" in {
      using(gen) in {
        n => Factorial.loop(n)
      }
    }
    measure method "recursive" in {
      using(gen) in {
        n => Factorial.recursive(n)
      }
    }
    measure method "recursiveTail" in {
      using(gen) in {
        n => Factorial.recursiveTail(n)
      }
    }
  }

}

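Benchmark results (the loop comes out faster; note that loop was originally posted with var acc = 1, i.e. an Int accumulator, so these timings presumably compare Int multiplication against BigInt multiplication):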

[info] Test group: Factorial.loop
[info] - Factorial.loop.Test-9 measurements:
[info]   - at N -> 1: passed
[info]     (mean = 0.01 ms, ci = <0.00 ms, 0.02 ms>, significance = 1.0E-10)
[info]   - at N -> 101: passed
[info]     (mean = 0.01 ms, ci = <0.01 ms, 0.01 ms>, significance = 1.0E-10)
[info]   - at N -> 201: passed
[info]     (mean = 0.02 ms, ci = <0.02 ms, 0.02 ms>, significance = 1.0E-10)
[info]   - at N -> 301: passed
[info]     (mean = 0.03 ms, ci = <0.02 ms, 0.03 ms>, significance = 1.0E-10)
[info]   - at N -> 401: passed
[info]     (mean = 0.03 ms, ci = <0.03 ms, 0.04 ms>, significance = 1.0E-10)
[info]   - at N -> 501: passed
[info]     (mean = 0.04 ms, ci = <0.03 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 601: passed
[info]     (mean = 0.04 ms, ci = <0.04 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 701: passed
[info]     (mean = 0.05 ms, ci = <0.05 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 801: passed
[info]     (mean = 0.06 ms, ci = <0.05 ms, 0.06 ms>, significance = 1.0E-10)
[info]   - at N -> 901: passed
[info]     (mean = 0.06 ms, ci = <0.05 ms, 0.07 ms>, significance = 1.0E-10)

[info] Test group: Factorial.recursive
[info] - Factorial.recursive.Test-10 measurements:
[info]   - at N -> 1: passed
[info]     (mean = 0.00 ms, ci = <0.00 ms, 0.01 ms>, significance = 1.0E-10)
[info]   - at N -> 101: passed
[info]     (mean = 0.05 ms, ci = <0.01 ms, 0.09 ms>, significance = 1.0E-10)
[info]   - at N -> 201: passed
[info]     (mean = 0.03 ms, ci = <0.02 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 301: passed
[info]     (mean = 0.07 ms, ci = <0.00 ms, 0.13 ms>, significance = 1.0E-10)
[info]   - at N -> 401: passed
[info]     (mean = 0.09 ms, ci = <0.01 ms, 0.18 ms>, significance = 1.0E-10)
[info]   - at N -> 501: passed
[info]     (mean = 0.10 ms, ci = <0.03 ms, 0.17 ms>, significance = 1.0E-10)
[info]   - at N -> 601: passed
[info]     (mean = 0.11 ms, ci = <0.08 ms, 0.15 ms>, significance = 1.0E-10)
[info]   - at N -> 701: passed
[info]     (mean = 0.13 ms, ci = <0.11 ms, 0.14 ms>, significance = 1.0E-10)
[info]   - at N -> 801: passed
[info]     (mean = 0.16 ms, ci = <0.13 ms, 0.19 ms>, significance = 1.0E-10)
[info]   - at N -> 901: passed
[info]     (mean = 0.21 ms, ci = <0.15 ms, 0.27 ms>, significance = 1.0E-10)

[info] Test group: Factorial.recursiveTail
[info] - Factorial.recursiveTail.Test-11 measurements:
[info]   - at N -> 1: passed
[info]     (mean = 0.00 ms, ci = <0.00 ms, 0.01 ms>, significance = 1.0E-10)
[info]   - at N -> 101: passed
[info]     (mean = 0.04 ms, ci = <0.03 ms, 0.05 ms>, significance = 1.0E-10)
[info]   - at N -> 201: passed
[info]     (mean = 0.12 ms, ci = <0.05 ms, 0.20 ms>, significance = 1.0E-10)
[info]   - at N -> 301: passed
[info]     (mean = 0.16 ms, ci = <-0.03 ms, 0.34 ms>, significance = 1.0E-10)
[info]   - at N -> 401: passed
[info]     (mean = 0.12 ms, ci = <0.09 ms, 0.16 ms>, significance = 1.0E-10)
[info]   - at N -> 501: passed
[info]     (mean = 0.17 ms, ci = <0.15 ms, 0.19 ms>, significance = 1.0E-10)
[info]   - at N -> 601: passed
[info]     (mean = 0.23 ms, ci = <0.19 ms, 0.26 ms>, significance = 1.0E-10)
[info]   - at N -> 701: passed
[info]     (mean = 0.25 ms, ci = <0.18 ms, 0.32 ms>, significance = 1.0E-10)
[info]   - at N -> 801: passed
[info]     (mean = 0.28 ms, ci = <0.21 ms, 0.36 ms>, significance = 1.0E-10)
[info]   - at N -> 901: passed
[info]     (mean = 0.32 ms, ci = <0.17 ms, 0.46 ms>, significance = 1.0E-10)
answered 2018-01-08T01:14:42.140