0

我有一个像这样的 DF:

time,channel,value
0,foo,5
0,bar,23
100,foo,42
...

我想要这样的DF:

time,foo,bar
0,5,23
100,42,...

在 Spark 2 中,我使用了这样的 UDAF:

/**
 * Spark 2 `UserDefinedAggregateFunction` that folds the (channel, value) pairs of one group
 * into a single struct whose fields `c0`, `c1`, ... hold the value of `channels(0)`,
 * `channels(1)`, ... respectively.
 *
 * Rows whose channel is not listed in `channels`, and NaN values, are skipped; a field that
 * never receives a value stays NaN.
 *
 * @param channels channel names; the position of a channel selects its struct field
 */
case class ColumnBuilderUDAF(channels: Seq[String]) extends UserDefinedAggregateFunction {

  /** Two input columns: the channel name and its numeric value. */
  @transient lazy val inputSchema: StructType = StructType(
    List(
      StructField("channel", StringType, nullable = false),
      StructField("value", DoubleType, nullable = false)
    )
  )

  /** One non-nullable Double slot per channel, named positionally (c0, c1, ...). */
  @transient lazy val bufferSchema: StructType = StructType(
    channels.toList.indices.map { idx =>
      StructField("c%d".format(idx), DoubleType, nullable = false)
    }
  )

  /** The result struct has exactly the same layout as the buffer. */
  @transient lazy val dataType: DataType = bufferSchema

  @transient lazy val deterministic: Boolean = false

  /** Every slot starts out as NaN, meaning "no value seen yet". */
  def initialize(buffer: MutableAggregationBuffer): Unit =
    for (slot <- channels.indices) buffer(slot) = NaN

  /** Writes the row's value into its channel's slot (a later value overwrites an earlier one). */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val slot = channels.indexOf(input.getAs[String](0))
    val knownChannel = slot >= 0 && slot < channels.length
    if (knownChannel) {
      val reading = input.getAs[Double](1)
      if (!reading.isNaN) buffer(slot) = reading
    }
  }

  /** Fills only the still-empty (NaN) slots of `buffer1` from `buffer2`. */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    for (i <- channels.indices) {
      val incoming = buffer2.getAs[Double](i)
      if (!incoming.isNaN && buffer1.getAs[Double](i).isNaN) buffer1(i) = incoming
    }

  /** Copies the buffer slots into a row that carries the struct schema. */
  def evaluate(buffer: Row): Any = {
    val cells: Array[Any] = channels.indices.map(i => buffer.getAs[Double](i)).toArray[Any]
    new GenericRowWithSchema(cells, dataType.asInstanceOf[StructType])
  }
}

我这样使用:

// Aggregate each time group into one struct column "c" whose fields c0, c1 hold foo and bar.
val cb = ColumnBuilderUDAF(Seq("foo", "bar"))
val dfColumnar = df
  .groupBy($"time")
  .agg(cb($"channel", $"value").as("c"))

然后,我将 c.c0、c.c1 等分别重命名为 foo、bar 等。

在 Spark 3 中,UDAF 已被弃用,应改用 Aggregator。所以我开始像这样移植它:

/**
 * Spark 3 typed `Aggregator` that turns the (channel, value) pairs of one group into a single
 * row with one `Double` column per entry of `channels` (column name = channel name, same order).
 *
 * Pairs whose channel is not listed in `channels`, and NaN values, are ignored; a channel with
 * no usable value stays NaN. Within one partial buffer a later value overwrites an earlier one
 * (`reduce`); when partial buffers are combined, the first buffer's non-NaN value wins (`merge`).
 *
 * @param channels channel names that become the output columns
 */
case class ColumnBuilder(channels: Seq[String]) extends Aggregator[(String, Double), Array[Double], Row] {

  /** Output schema: one non-nullable `Double` column per channel, named after the channel. */
  private lazy val outputSchema: StructType =
    StructType(channels.map(StructField(_, DoubleType, nullable = false)))

  /** Buffers are plain `Array[Double]`s, serialized with Java serialization. */
  lazy val bufferEncoder: Encoder[Array[Double]] = Encoders.javaSerialization[Array[Double]]

  /**
   * A fresh all-NaN buffer, one slot per channel.
   *
   * This must allocate a new array on every call. `reduce` and `merge` update their first
   * argument in place, so a single cached array (e.g. a `lazy val`) would be shared by every
   * group that starts from `zero`, leaking values from one group into another.
   */
  def zero: Array[Double] = Array.fill(channels.length)(Double.NaN)

  /**
   * Stores `a._2` in the slot of channel `a._1`, unless the channel is unknown or the value is NaN.
   * Mutates and returns `b`.
   */
  def reduce(b: Array[Double], a: (String, Double)): Array[Double] = {
    val index = channels.indexOf(a._1)
    if (index >= 0 && !a._2.isNaN) b(index) = a._2
    b
  }

  /** Fills the NaN slots of `b1` from `b2`; slots already set in `b1` are kept. Mutates and returns `b1`. */
  def merge(b1: Array[Double], b2: Array[Double]): Array[Double] = {
    (0 until b1.length.min(b2.length)).foreach(i => if (b1(i).isNaN) b1(i) = b2(i))
    b1
  }

  /** Wraps the final buffer in a `Row` that carries the output schema. */
  def finish(reduction: Array[Double]): Row =
    new GenericRowWithSchema(reduction.map(x => x: Any), outputSchema)

  /**
   * Encoder for the output rows.
   *
   * Spark requires this to be an `ExpressionEncoder`, so a hand-written `Encoder[Row]` fails
   * with a `ClassCastException`. `RowEncoder` builds the required `ExpressionEncoder[Row]`
   * from the schema.
   *
   * NOTE: `RowEncoder.apply(schema)` is the Spark 3.0-3.4 API. On Spark versions that provide
   * `Encoders.row`, `Encoders.row(outputSchema)` is the public equivalent.
   */
  lazy val outputEncoder: Encoder[Row] =
    org.apache.spark.sql.catalyst.encoders.RowEncoder(outputSchema)
}

我不知道如何实现 Encoder[Row],因为 Spark 没有为 Row 预定义的 Encoder。如果我只是简单地写一个这样的成员:

  // Attempted outputEncoder for ColumnBuilder -- this does NOT work. At runtime Spark requires
  // outputEncoder to be an ExpressionEncoder, and this anonymous Encoder[Row] is not one,
  // so using it leads to a ClassCastException.
  val outputEncoder: Encoder[Row] = new Encoder[Row] {
    // One non-nullable Double column per channel, named after the channel.
    val schema: StructType = StructType(channels.map(StructField(_, DoubleType, nullable = false)))

    val clsTag: ClassTag[Row] = classTag[Row]
  }

我会得到一个 ClassCastException,因为 outputEncoder 实际上必须是 ExpressionEncoder。

那么,我该如何正确实现它呢?还是说我仍然必须使用已弃用的 UDAF?

4

1 回答 1

1

You can do this with groupBy and pivot:

import spark.implicits._
import org.apache.spark.sql.functions._

// Sample long-format data: one row per (time, channel) reading.
val df = Seq(
  (0, "foo", 5),
  (0, "bar", 23),
  (100, "foo", 42)
).toDF("time", "channel", "value")

// One output row per time and one column per distinct channel value. `first` picks the value
// for each (time, channel) cell; cells with no reading come out as null.
val pivoted = df
  .groupBy($"time")
  .pivot($"channel")
  .agg(first($"value"))

pivoted.show(false)

Output:

+----+----+---+
|time|bar |foo|
+----+----+---+
|100 |null|42 |
|0   |23  |5  |
+----+----+---+
于 2020-07-22T08:05:04.630 回答