1

在 scala 中,我有一个返回值的函数列表。函数的执行顺序很重要,因为 function 的参数是 functionn的输出n-1

这提示使用foldLeft,例如:

val base: A
val funcs: Seq[Function[A, A]]

funcs.foldLeft(base)(x, f) => f(x)

(细节:类型A实际上是一个 Spark DataFrame)。

但是,每个函数的结果是互斥的,最后我想要每个函数的所有结果的并集。这提示使用 a map,例如:

funcs.map(f => f(base)).reduce(_.union(_)

但是这里每个函数都应用到base了我不想要的。

Short:一个可变长度的有序函数列表需要返回一个相等长度的返回值列表,其中每个值n-1都是函数的输入n(从basewhere开始n=0)。这样可以连接结果值。

我怎样才能做到这一点?

编辑 示例:

case class X(id:Int, value:Int)
val base = spark.createDataset(Seq(X(1, 1), X(2, 2), X(3, 3), X(4, 4), X(5, 5))).toDF

def toA = (x: DataFrame) => x.filter('value.mod(2) === 1).withColumn("value", lit("a"))
def toB = (x: DataFrame) => x.withColumn("value", lit("b"))

val a = toA(base)
val remainder = base.join(a, Seq("id"), "leftanti")
val b = toB(remainder)

a.union(b)

+---+-----+
| id|value|
+---+-----+
|  1|    a|
|  3|    a|
|  5|    a|
|  2|    b|
|  4|    b|
+---+-----+

这应该适用于任意数量的函数(例如toAtoB...。toN每次计算前一个结果的剩余部分并将其传递给下一个函数。最后,一个联合应用于所有结果。

4

2 回答 2

1

Seq已经有一个scanLeft开箱即用的方法:

funcs.scanLeft(base)((acc, f) => f(acc)).tail

scanLeft如果您不想base被包含,请确保删除结果的第一个元素。


仅使用 foldLeft 也是可能的:

funcs.foldLeft((base, List.empty[A])){ case ((x, list), f) => 
  val res = f(x)
  (res, res :: list) 
}._2.reverse.reduce(_.union(_))

或者:

funcs.foldLeft((base, Vector.empty[A])){ case ((x, list), f) => 
  val res = f(x)
  (res, list :+ res) 
}._2.reduce(_.union(_))

诀窍是积累成一个Seq里面的fold

例子:

scala> val base = 7
base: Int = 7

scala> val funcs: List[Int => Int] = List(_ * 2, _ + 3)
funcs: List[Int => Int] = List($$Lambda$1772/1298658703@7d46af18, $$Lambda$1773/107346281@5470fb9b)

scala> funcs.foldLeft((base, Vector.empty[Int])){ case ((x, list), f) => 
     |   val res = f(x)
     |   (res, list :+ res) 
     | }._2
res8: scala.collection.immutable.Vector[Int] = Vector(14, 17)

scala> .reduce(_ + _)
res9: Int = 31
于 2017-03-07T12:35:55.310 回答
0

我有一个使用普通集合的简化解决方案,但同样的原则适用。

val list: List[Int] = List(1, 2, 3, 4, 5)
val funcs: Seq[Function[List[Int], List[Int]]] = Seq(times2, by2)

funcs.foldLeft(list) { case(collection, func) => func(collection) } foreach println // prints 1 2 3 4 5

def times2(l: List[Int]): List[Int] = l.map(_ * 2)

def by2(l: List[Int]): List[Int] = l.map(_ / 2)

如果您想要一个减少的值作为最终输出,则此解决方案不成立,例如 single Int;因此这作为: F[B] -> F[B] -> F[B]而不是作为F[B] -> F[B] -> B; 虽然我想这是你需要的。

于 2017-03-07T12:23:46.973 回答