0

我正在尝试修改下面的代码,让传给 reduceLeft 的函数接受第三个 Point 参数,但是下面这一行:

val cumulative = points.reduceLeft((a: Point, b: Point, c: Point) => 

导致此编译时错误:

Multiple markers at this line
    - type mismatch; found : (scala.algorithms.Point, scala.algorithms.Point, scala.algorithms.Point) => 
     scala.algorithms.Point required: (?, scala.algorithms.Point) => ?
    - type mismatch; found : (scala.algorithms.Point, scala.algorithms.Point, scala.algorithms.Point) => 
     scala.algorithms.Point required: (?, scala.algorithms.Point) => ?

完整代码:

package scala.algorithms

/**
 * Modified from http://garysieling.com/blog/implementing-k-means-in-scala
 * 
 */

/**
 * A point in three-dimensional space.
 *
 * @param x the x coordinate
 * @param y the y coordinate
 * @param z the z coordinate
 */
class Point(val x: Double, val y: Double, val z: Double) {

  /** Renders all three coordinates, e.g. `(1.0, 2.0, 3.0)`. */
  override def toString(): String = s"($x, $y, $z)"

  /**
   * Squared Euclidean distance between this point and `p`.
   *
   * The square root is deliberately omitted: callers only compare distances
   * against each other, and squaring preserves their order.
   *
   * @param p the other point
   * @return `(x - p.x)^2 + (y - p.y)^2 + (z - p.z)^2`
   */
  def dist(p: Point): Double = {
    val dx = x - p.x
    val dy = y - p.y
    val dz = z - p.z
    dx * dx + dy * dy + dz * dz
  }
}

/**
 * Small k-means clustering demo over a fixed set of 3-D points.
 *
 * The points are first split into `k` clusters round-robin. The assignment
 * step ([[iterate]]) is then repeated, and each clustering is printed via
 * [[render]].
 */
object kmeans extends App {

  /**
   * Upper bound of the iteration loop at the end of this object. The loop runs
   * `iterate` for `i <- 0 to NUMBER_OF_CLUSTERS`, i.e. `NUMBER_OF_CLUSTERS + 1`
   * times. Despite the name, this is NOT the number of clusters; that is [[k]].
   */
  val NUMBER_OF_CLUSTERS = 5

  /** Number of clusters used for the initial round-robin split. */
  val k: Int = 2

  /**
   * Input data. It is sorted by the hash code of the string "x y" (z is
   * ignored), which just gives the points a fixed, non-obvious order.
   */
  val points: List[Point] = List(
    new Point(0, 0, 1),
    new Point(1, 0, 1),
    new Point(0, 1, 0)).sortBy(
      p => (p.x + " " + p.y).hashCode())

  /**
   * Coordinate-wise arithmetic mean (centroid) of a cluster.
   *
   * `reduceLeft` takes a BINARY function `(accumulated, next) => accumulated`.
   * The sum is therefore built two points at a time, no matter how many
   * coordinates a point has.
   *
   * @param points the members of one cluster; must be non-empty
   * @return the mean point
   * @throws IllegalArgumentException if `points` is empty
   */
  def clusterMean(points: List[Point]): Point = {
    require(points.nonEmpty, "cannot compute the mean of an empty cluster")
    val sum = points.reduceLeft((a: Point, b: Point) =>
      new Point(a.x + b.x, a.y + b.y, a.z + b.z))
    val n = points.length
    new Point(sum.x / n, sum.y / n, sum.z / n)
  }

  /**
   * Prints every cluster in ascending key order, with its mean and its members.
   *
   * @param points clusters keyed by cluster number; each must be non-empty
   */
  def render(points: Map[Int, List[Point]]): Unit = {
    for (clusterNumber <- points.keys.toSeq.sorted) {
      val members = points(clusterNumber)
      println("  Cluster " + clusterNumber)
      println("  Mean: " + clusterMean(members))
      // Point.toString already supplies its own closing parenthesis.
      members.foreach(p => println("    " + p))
    }
  }

  /** Initial clustering: the i-th point goes to cluster `i % k`. */
  val clusters: Map[Int, List[Point]] =
    points.zipWithIndex.groupBy(_._2 % k).transform((_, indexed) => indexed.map(_._1))

  println("Initial State: ")
  render(clusters)

  /**
   * One k-means assignment step. It computes each cluster's mean, regroups
   * every point of [[points]] under the key of its nearest mean, and prints
   * the result.
   *
   * Ties go to the cluster with the smallest key. Clusters that attract no
   * points are absent from the result.
   *
   * @param clusters current clustering; every cluster must be non-empty
   * @return the new clustering, keyed by keys taken from `clusters`
   */
  def iterate(clusters: Map[Int, List[Point]]): Map[Int, List[Point]] = {
    // Compute the means once per step, not once per point. Use a stable
    // (sorted) key order so every point is measured against the same centroids.
    // A Vector is used instead of mapping over Map.keys, which is a Set at
    // runtime: that would collapse equal distances and scramble the indices.
    val means: Vector[(Int, Point)] =
      clusters.keys.toVector.sorted.map(key => key -> clusterMean(clusters(key)))

    // Key of the mean nearest to p; minBy keeps the first (smallest key) on ties.
    def closest(p: Point): Int =
      means.minBy { case (_, center) => p.dist(center) }._1

    val newClusters = points.groupBy(p => closest(p))

    render(newClusters)

    newClusters
  }

  var clusterToTest = clusters
  for (i <- 0 to NUMBER_OF_CLUSTERS) {
    println("Iteration: " + i)
    clusterToTest = iterate(clusterToTest)
  }
}

阅读 http://www.scala-lang.org/api/current/index.html#index.index-r 上 reduceLeft 方法的文档:

Applies a binary operator to all elements of this sequence, going left to right.

是不是我需要换用另一个方法?

reduceLeft 方法还在多个 trait 中有定义:

IndexedSeqOptimized LinearSeqOptimized TraversableOnce TraversableProxyLike TraversableForwarder Stream ParIterableLike

我怎么知道实际调用的是哪个 trait 中的 reduceLeft 实现?

4

1 回答 1

0

reduceLeft 方法接受一个双参数函数:第一个参数是到目前为止的累积结果,第二个参数是序列中的下一个元素。参数个数与 Point 有几个坐标无关,所以对三维点也只需要两个参数,在函数体中把 z 坐标一起相加即可:

points.reduce( (a, b) => new Point(a.x + b.x, a.y + b.y, a.z + b.z))

请注意,如果 points 为空,reduce 会抛出异常。可以使用 reduceOption 或 fold 来避免异常:

points.fold(new Point(0, 0, 0))( (a, b) => new Point(a.x + b.x, a.y + b.y, a.z + b.z))

您可以通过文档查看方法是在哪里定义的:

"Definition Classes"(定义所在的类)一栏显示:TraversableOnce

(这一栏位于 reduceLeft 方法的说明中。)

于 2013-07-12T11:38:26.253 回答