我正在从 Java 到 Scala 重新实现一些代码(一个简单的贝叶斯推理算法,但这并不重要)。我想以尽可能高性能的方式实现它,同时通过尽可能避免可变性来保持代码的清洁和功能。
以下是 Java 代码片段:
// initialize
double lP = Math.log(prior);
double lPC = Math.log(1-prior);
// accumulate probabilities from each annotation object into lP and lPC
for (Annotation annotation : annotations) {
float prob = annotation.getProbability();
if (isValidProbability(prob)) {
lP += logProb(prob);
lPC += logProb(1 - prob);
}
}
很简单,对吧?所以我决定第一次尝试使用 Scala foldLeft 和 map 方法。因为我有两个值要累积,所以累加器是一个元组:
val initial = (math.log(prior), math.log(1-prior))
val probs = annotations map (_.getProbability)
val (lP,lPC) = probs.foldLeft(initial) ((r,p) => {
if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
})
不幸的是,这段代码的执行速度比 Java 慢了大约 5 倍(使用简单且不精确的度量;只是在循环中调用了 10000 次代码)。一个缺陷非常明显;我们遍历列表两次,一次在 map 调用中,另一次在 foldLeft 中。所以这是一个遍历列表一次的版本。
val (lP,lPC) = annotations.foldLeft(initial) ((r,annotation) => {
val p = annotation.getProbability
if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
})
这个更好!它的性能比 Java 代码差大约 3 倍。我的下一个预感是,在折叠的每个步骤中创建所有新元组可能都会涉及一些成本。所以我决定尝试一个遍历列表两次但不创建元组的版本。
val lP = annotations.foldLeft(math.log(prior)) ((r,annotation) => {
val p = annotation.getProbability
if(isValidProbability(p)) r + logProb(p) else r
})
val lPC = annotations.foldLeft(math.log(1-prior)) ((r,annotation) => {
val p = annotation.getProbability
if(isValidProbability(p)) r + logProb(1-p) else r
})
这与以前的版本大致相同(比 Java 版本慢 3 倍)。并不奇怪,但我充满希望。
所以我的问题是,有没有更快的方法在 Scala 中实现这个 Java 代码片段,同时保持 Scala 代码干净,避免不必要的可变性并遵循 Scala 习语?我确实希望最终在并发环境中使用此代码,因此保持不变性的价值可能超过单线程中较慢的性能。