10

假设我想在 Scala 中使用可变映射来跟踪我看到某些字符串的次数。在单线程上下文中,这很容易:

import scala.collection.mutable.{ Map => MMap }

class Counter {
  val counts = MMap.empty[String, Int].withDefaultValue(0)

  def add(s: String): Unit = counts(s) += 1
}

不幸的是,这不是线程安全的,因为getandupdate不会自动发生。

并发映射向可变映射 API添加了一些原子操作,但不是我需要的,它看起来像这样:

def replace(k: A, f: B => B): Option[B]

我知道我可以使用ScalaSTMTMap

import scala.concurrent.stm._

class Counter {
  val counts =  TMap.empty[String, Int]

  def add(s: String): Unit = atomic { implicit txn =>
    counts(s) = counts.get(s).getOrElse(0) + 1
  }
}

但是(目前)这仍然是一个额外的依赖项。其他选项将包括参与者(另一个依赖项)、同步(可能效率较低)或 Java 的原子引用较少惯用)。

一般来说,我会避免在 Scala 中使用可变映射,但我偶尔会需要这种东西,最近我使用了 STM 方法(而不是只是交叉手指,希望我不会被天真解决方案)。

我知道这里有很多权衡(额外的依赖与性能与清晰度等),但在 Scala 2.10 中是否有类似“正确”的答案?

4

4 回答 4

10

这个怎么样?假设您现在并不真正需要通用replace方法,只需要一个计数器。

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

object CountedMap {
  private val counts = new ConcurrentHashMap[String, AtomicInteger]

  def add(key: String): Int = {
    val zero = new AtomicInteger(0)
    val value = Option(counts.putIfAbsent(key, zero)).getOrElse(zero)
    value.incrementAndGet
  }
}

您可以获得比在整个地图上同步更好的性能,并且您还可以获得原子增量。

于 2013-08-09T15:48:43.843 回答
3

最简单的解决方案肯定是同步。如果没有太多争用,性能可能不会那么差。

否则,您可以尝试汇总您自己的类似 STM 的replace实现。这样的事情可能会做:

object ConcurrentMapOps {
  private val rng = new util.Random
  private val MaxReplaceRetryCount = 10
  private val MinReplaceBackoffTime: Long = 1
  private val MaxReplaceBackoffTime: Long = 20
}
implicit class ConcurrentMapOps[A, B]( val m: collection.concurrent.Map[A,B] ) {
  import ConcurrentMapOps._
  private def replaceBackoff() {
    Thread.sleep( (MinReplaceBackoffTime + rng.nextFloat * (MaxReplaceBackoffTime - MinReplaceBackoffTime) ).toLong ) // A bit crude, I know
  }

  def replace(k: A, f: B => B): Option[B] = {
    m.get( k ) match {
      case None => return None
      case Some( old ) =>
        var retryCount = 0
        while ( retryCount <= MaxReplaceRetryCount ) {
          val done = m.replace( k, old, f( old ) )
          if ( done ) {
            return Some( old )
          }
          else {         
            retryCount += 1
            replaceBackoff()
          }
        }
        sys.error("Could not concurrently modify map")
    }
  }
}

请注意,冲突问题仅限于给定的键。如果两个线程访问同一个映射但使用不同的键,则不会发生冲突,并且替换操作总是第一次成功。如果检测到冲突,我们稍等片刻(随机时间,以尽量减少线程永远争夺同一个密钥的可能性)然后重试。

我不能保证这是生产就绪的(我现在只是把它扔掉了),但这可能会奏效。

更新:当然(正如 Ionuț G. Stan 指出的那样),如果你想要的只是增加/减少一个值,javaConcurrentHashMap已经以无锁方式提供了这些操作。replace如果您需要一种将转换函数作为参数的更通用的方法,则我的上述解决方案适用。

于 2013-08-09T15:46:34.610 回答
2

如果您的地图只是作为 val 坐在那里,您就是在自找麻烦。如果它符合您的用例,我会推荐类似的东西

class Counter {
  private[this] myCounts = MMap.empty[String, Int].withDefaultValue(0)
  def counts(s: String) = myCounts.synchronized { myCounts(s) }
  def add(s: String) = myCounts.synchronized { myCounts(s) += 1 }
  def getCounts = myCounts.synchronized { Map[String,Int]() ++ myCounts }
}

用于低争用用途。对于高争用,您应该使用旨在支持此类使用的并发映射(例如java.util.concurrent.ConcurrentHashMap)并将值包装在AtomicWhatever.

于 2013-08-09T17:15:27.760 回答
2

如果您可以使用基于未来的界面:

trait SingleThreadedExecutionContext {
  val ec = ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor())
}

class Counter extends SingleThreadedExecutionContext {
  private val counts = MMap.empty[String, Int].withDefaultValue(0)

  def get(s: String): Future[Int] = future(counts(s))(ec)

  def add(s: String): Future[Unit] = future(counts(s) += 1)(ec)
}

测试将如下所示:

class MutableMapSpec extends Specification {

  "thread safe" in {

    import ExecutionContext.Implicits.global

    val c = new Counter
    val testData = Seq.fill(16)("1")
    await(Future.traverse(testData)(c.add))
    await(c.get("1")) mustEqual 16
  }
}
于 2013-08-09T18:21:59.600 回答