5

我正在从 Java 到 Scala 重新实现一些代码(一个简单的贝叶斯推理算法,但这并不重要)。我想以尽可能高性能的方式实现它,同时通过尽可能避免可变性来保持代码的清洁和功能。

以下是 Java 代码片段:

    // initialize
    double lP  = Math.log(prior);
    double lPC = Math.log(1-prior);

    // accumulate probabilities from each annotation object into lP and lPC
    for (Annotation annotation : annotations) {
        float prob = annotation.getProbability();
        if (isValidProbability(prob)) {
            lP  += logProb(prob);
            lPC += logProb(1 - prob);
        }
    } 

很简单,对吧?所以我决定第一次尝试使用 Scala foldLeft 和 map 方法。因为我有两个值要累积,所以累加器是一个元组:

    val initial  = (math.log(prior), math.log(1-prior))
    val probs    = annotations map (_.getProbability)
    val (lP,lPC) = probs.foldLeft(initial) ((r,p) => {
      if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
    })

不幸的是,这段代码的执行速度比 Java 慢了大约 5 倍(使用简单且不精确的度量;只是在循环中调用了 10000 次代码)。一个缺陷非常明显;我们遍历列表两次,一次在 map 调用中,另一次在 foldLeft 中。所以这是一个遍历列表一次的版本。

    val (lP,lPC) = annotations.foldLeft(initial) ((r,annotation) => {
      val  p = annotation.getProbability
      if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
    })

这个更好!它的性能比 Java 代码差大约 3 倍。我的下一个预感是,在折叠的每个步骤中创建所有新元组可能都会涉及一些成本。所以我决定尝试一个遍历列表两次但不创建元组的版本。

    val lP = annotations.foldLeft(math.log(prior)) ((r,annotation) => {
       val  p = annotation.getProbability
       if(isValidProbability(p)) r + logProb(p) else r
    })
    val lPC = annotations.foldLeft(math.log(1-prior)) ((r,annotation) => {
      val  p = annotation.getProbability
      if(isValidProbability(p)) r + logProb(1-p) else r
    })

这与以前的版本大致相同(比 Java 版本慢 3 倍)。并不奇怪,但我充满希望。

所以我的问题是,有没有更快的方法在 Scala 中实现这个 Java 代码片段,同时保持 Scala 代码干净,避免不必要的可变性并遵循 Scala 习语?我确实希望最终在并发环境中使用此代码,因此保持不变性的价值可能超过单线程中较慢的性能。

4

5 回答 5

4

首先,您的一些处罚可能来自您使用的收藏类型。但其中大部分可能是您实际上无法通过运行循环两次来避免的对象创建,因为必须将数字装箱。

相反,您可以创建一个可变类来为您累积值:

class LogOdds(var lp: Double = 0, var lpc: Double = 0) {
  def *=(p: Double) = {
    if (isValidProbability(p)) {
      lp += logProb(p)
      lpc += logProb(1-p)
    }
    this  // Pass self on so we can fold over the operation
  }
  def toTuple = (lp, lpc)
}

现在,尽管您可以不安全地使用它,但您不必这样做。事实上,你可以把它折叠起来。

annotations.foldLeft(new LogOdds()) { (r,ann) => r *= ann.getProbability } toTuple

如果你使用这种模式,所有可变的不安全性都被隐藏在折叠内;它永远不会逃脱。

现在,您不能进行平行折叠,但可以进行聚合,这就像折叠带有额外的操作来组合碎片。所以你添加方法

def **(lo: LogOdds) = new LogOdds(lp + lo.lp, lpc + lo.lpc)

LogOdds然后

annotations.aggregate(new LogOdds())(
  (r,ann) => r *= ann.getProbability,
  (l,r) => l**r
).toTuple

你会很高兴的。

(请随意使用非数学符号,但由于您基本上是在乘以概率,因此乘法符号似乎比合并概率或类似的东西更可能给出正在发生的事情的直观想法。)

于 2012-02-02T17:36:45.183 回答
3

您可以实现一个尾递归方法,该方法将由编译器转换为 while 循环,因此应该与 Java 版本一样快。或者,您可以只使用循环 - 如果它只是在方法中使用局部变量(例如,请参阅 Scala 集合源代码中的广泛使用),则没有法律禁止它。

def calc(lst: List[Annotation], lP: Double = 0, lPC: Double = 0): (Double, Double) = {
  if (lst.isEmpty) (lP, lPC)
  else {
    val prob = lst.head.getProbability
    if (isValidProbability(prob)) 
      calc(lst.tail, lP + logProb(prob), lPC + logProb(1 - prob))
    else 
      calc(lst.tail, lP, lPC)
  }
}

折叠的优点是它是可并行的,这可能会导致它比多核机器上的 Java 版本更快(参见其他答案)。

于 2012-02-02T18:10:45.293 回答
2

作为一种旁注:您可以避免使用以下惯用方式遍历列表两次view

val probs = annotations.view.map(_.getProbability).filter(isValidProbability)

val (lP, lPC) = ((logProb(prior), logProb(1 - prior)) /: probs) {
   case ((pa, ca), p) => (pa + logProb(p), ca + logProb(1 - p))
}

这可能不会让您获得比第三个版本更好的性能,但对我来说感觉更优雅。

于 2012-02-02T17:51:05.103 回答
2

首先,让我们解决性能问题:除了使用 while 循环之外,没有其他方法可以像 Java 一样快速地实现它。基本上,JVM 无法将 Scala 循环优化到优化 Java 循环的程度。其原因甚至是 JVM 人员关心的问题,因为它也妨碍了他们并行库的工作。

现在,回到 Scala 的性能,你也可以使用.view避免创建新集合的map步骤,但我认为该map步骤总是会导致性能变差。问题是,您正在将集合转换为一个参数化的 on Double,它必须被装箱和拆箱。

但是,有一种可能的优化方法:使其并行。如果您要求.parannotations其设为并行集合,则可以使用fold

val parAnnot = annotations.par
val lP = parAnnot.map(_.getProbability).fold(math.log(prior)) ((r,p) => {
   if(isValidProbability(p)) r + logProb(p) else r
})
val lPC = parAnnot.map(_.getProbability).fold(math.log(1-prior)) ((r,p) => {
  if(isValidProbability(p)) r + logProb(1-p) else r
})

为避免单独的map步骤,请按照 Rex 的建议使用aggregate代替。fold

对于奖励积分,您可以使用Future使两个计算并行运行。不过,我怀疑通过带回元组并一次性运行它会获得更好的性能。你必须对这些东西进行基准测试,看看什么效果更好。

在并行集合上,它可能会首先filter获得有效的注释。或者,也许,collect

val parAnnot = annottions.par.view map (_.getProbability) filter (isValidProbability(_)) force;

或者

val parAnnot = annotations.par collect { case annot if isValidProbability(annot.getProbability) => annot.getProbability }

无论如何,基准。

于 2012-02-02T17:52:21.170 回答
1

当前无法在没有装箱的情况下与 scala 集合库进行交互。因此double,Java 中的原始 s 将在fold操作中不断地装箱和拆箱,即使您没有将它们包装在 a 中Tuple2(这专门的 - 但当然您已经支付了每次创建新对象的性能开销) .

于 2012-02-02T17:35:49.600 回答