10

我在向一位朋友解释说,我希望 Scala 中的非尾递归函数比尾递归函数慢,所以我决定验证一下。我以两种方式编写了一个很好的旧阶乘函数并尝试比较结果。这是代码:

def main(args: Array[String]): Unit = {
  val N = 2000 // not too much or else stackoverflows
  var spent1: Long = 0
  var spent2: Long = 0
  for ( i <- 1 to 100 ) { // repeat to average the results
    val t0 = System.nanoTime
    factorial(N)
    val t1 = System.nanoTime
    tailRecFact(N)
    val t2 = System.nanoTime
    spent1 += t1 - t0
    spent2 += t2 - t1
  }
  println(spent1/1000000f) // get milliseconds
  println(spent2/1000000f)
}

@tailrec
def tailRecFact(n: BigInt, s: BigInt = 1): BigInt = if (n == 1) s else tailRecFact(n - 1, s * n)

def factorial(n: BigInt): BigInt = if (n == 1) 1 else n * factorial(n - 1)

结果让我很困惑,我得到了这样的输出:

578.2985

870.22125

这意味着非尾递归函数比尾递归函数快 30%,并且操作数相同!

什么可以解释这些结果?

4

2 回答 2

10

它实际上不是你首先要看的地方。原因在于你的尾递归方法,你正在用它的乘法做更多的工作。尝试在递归调用中交换参数 n 和 s 的顺序,它会变平。

def tailRecFact(n: BigInt, s: BigInt): BigInt = if (n == 1) s else tailRecFact(n - 1, n * s)

此外,此示例中的大部分时间都用于 BigInt 操作,这使递归调用的时间相形见绌。如果我们将这些转换为 Ints(编译为 Java 原语),那么您可以看到尾递归(goto)与方法调用的比较。

object Test extends App {

  val N = 2000

  val t0 = System.nanoTime()
  for ( i <- 1 to 1000 ) {
    factorial(N)
  }
  val t1 = System.nanoTime
  for ( i <- 1 to 1000 ) {
    tailRecFact(N, 1)
  }
  val t2 = System.nanoTime

  println((t1 - t0) / 1000000f) // get milliseconds
  println((t2 - t1) / 1000000f)

  def factorial(n: Int): Int = if (n == 1) 1 else n * factorial(n - 1)

  @tailrec
  final def tailRecFact(n: Int, s: Int): Int = if (n == 1) s else tailRecFact(n - 1, s * n)
}

95.16733
3.987605

出于兴趣,反编译的输出

  public final scala.math.BigInt tailRecFact(scala.math.BigInt, scala.math.BigInt);
    Code:
       0: aload_1       
       1: iconst_1      
       2: invokestatic  #16                 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
       5: invokestatic  #20                 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
       8: ifeq          13
      11: aload_2       
      12: areturn       
      13: aload_1       
      14: getstatic     #26                 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
      17: iconst_1      
      18: invokevirtual #30                 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
      21: invokevirtual #36                 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
      24: aload_1       
      25: aload_2       
      26: invokevirtual #39                 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
      29: astore_2      
      30: astore_1      
      31: goto          0

  public scala.math.BigInt factorial(scala.math.BigInt);
    Code:
       0: aload_1       
       1: iconst_1      
       2: invokestatic  #16                 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
       5: invokestatic  #20                 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
       8: ifeq          21
      11: getstatic     #26                 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
      14: iconst_1      
      15: invokevirtual #30                 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
      18: goto          40
      21: aload_1       
      22: aload_0       
      23: aload_1       
      24: getstatic     #26                 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
      27: iconst_1      
      28: invokevirtual #30                 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
      31: invokevirtual #36                 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
      34: invokevirtual #47                 // Method factorial:(Lscala/math/BigInt;)Lscala/math/BigInt;
      37: invokevirtual #39                 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
      40: areturn   
于 2013-10-09T09:40:52.197 回答
9

除了@monkjack 显示的问题(即乘以小 * 大比大 * 小快,这确实占了更大的差异),您的算法在每种情况下都是不同的,因此它们并不是真正可比的。

在尾递归版本中,您将大到小相乘:

n * n-1 * n-2 * ... * 2 * 1

在非尾递归版本中,您将小到大相乘:

n * (n-1 * (n-2 * (... * (2 * 1))))

如果您更改尾递归版本,使其从小到大相乘:

def tailRecFact2(n: BigInt) = {
  def loop(x: BigInt, out: BigInt): BigInt =
    if (x > n) out else loop(x + 1, x * out)
  loop(1, 1)
}

那么尾递归比普通递归快大约 20%,而不是像你只进行 Monkjack 的修正那样慢 10%。这是因为将小的 BigInts 相乘比将大的 BigInts 相乘要快。

于 2013-10-09T12:34:55.477 回答