
Spark 3.0 deprecates UserDefinedAggregateFunction, and I am trying to use Aggregator instead. The basic usage of Aggregator is simple, but I am struggling with a more generic version of the function.

I will try to explain my problem with this example, a collect_set. It is not my actual case, but it is easier to explain the problem with it:

import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

// Collects the distinct Int values of the given column into a Set
class CollectSetDemoAgg(name: String) extends Aggregator[Row, Set[Int], Set[Int]] {
  override def zero: Set[Int] = Set.empty
  override def reduce(b: Set[Int], a: Row): Set[Int] = b + a.getInt(a.fieldIndex(name))
  override def merge(b1: Set[Int], b2: Set[Int]): Set[Int] = b1 ++ b2
  override def finish(reduction: Set[Int]): Set[Int] = reduction
  override def bufferEncoder: Encoder[Set[Int]] = Encoders.kryo[Set[Int]]
  override def outputEncoder: Encoder[Set[Int]] = ExpressionEncoder()
}

// using it:
df.agg(new CollectSetDemoAgg("rank").toColumn as "result").show()

I prefer .toColumn over .udf.register here, but that is not the point.
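
For context, the registration route mentioned above would look roughly like this with a value-typed aggregator (a sketch assuming Spark 3.0's functions.udaf; IntSetAgg and the "ranks" view are my own illustrative names, not from the question):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions

// Value-typed variant of the aggregator, only to illustrate registration
object IntSetAgg extends Aggregator[Int, Set[Int], Set[Int]] {
  override def zero: Set[Int] = Set.empty
  override def reduce(b: Set[Int], a: Int): Set[Int] = b + a
  override def merge(b1: Set[Int], b2: Set[Int]): Set[Int] = b1 ++ b2
  override def finish(reduction: Set[Int]): Set[Int] = reduction
  override def bufferEncoder: Encoder[Set[Int]] = Encoders.kryo[Set[Int]]
  override def outputEncoder: Encoder[Set[Int]] = ExpressionEncoder()
}

spark.udf.register("collect_set_demo", functions.udaf(IntSetAgg))
df.createOrReplaceTempView("ranks") // hypothetical view name
spark.sql("SELECT collect_set_demo(rank) AS result FROM ranks").show()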

The problem: I cannot make a generic version of this aggregator; it only works with integers.

I have tried:

class CollectSetDemo(name: String) extends Aggregator[Row, Set[Any], Set[Any]] 

It crashes with the error:

No Encoder found for Any
- array element class: "java.lang.Object"
- root class: "scala.collection.immutable.Set"
java.lang.UnsupportedOperationException: No Encoder found for Any
- array element class: "java.lang.Object"
- root class: "scala.collection.immutable.Set"
    at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$serializerFor$1(ScalaReflection.scala:567)
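
In other words, Catalyst reflection can derive an encoder for a concrete element type, but not for Any. A minimal illustration of the difference (my own sketch, not part of the original question):

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

ExpressionEncoder[Set[Int]]() // fine: Set[Int] is encoded as array<int>
ExpressionEncoder[Set[Any]]() // throws UnsupportedOperationException: No Encoder found for Any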

I could not go with CollectSetDemo[T], since in that case I could not get outputEncoder working properly. Also, when using a udaf I can only operate with Spark data types, columns, etc.


2 Answers


I have not found a good way to solve this yet, but I was able to work around it to some degree. The code is partly borrowed from RowEncoder:

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.typeTag

import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.types._

class CollectSetDemoAgg(name: String, fieldType: DataType) extends Aggregator[Row, Set[Any], Any] {
  override def zero: Set[Any] = Set.empty
  override def reduce(b: Set[Any], a: Row): Set[Any] = b + a.get(a.fieldIndex(name))
  override def merge(b1: Set[Any], b2: Set[Any]): Set[Any] = b1 ++ b2
  override def finish(reduction: Set[Any]): Any = reduction.toSeq
  override def bufferEncoder: Encoder[Set[Any]] = Encoders.kryo[Set[Any]]

  // Build the output encoder from the requested result data type
  override def outputEncoder: Encoder[Any] = {
    val mirror = ScalaReflection.mirror
    val tt = fieldType match {
      case ArrayType(LongType, _) => typeTag[Seq[Long]]
      case ArrayType(IntegerType, _) => typeTag[Seq[Int]]
      case ArrayType(StringType, _) => typeTag[Seq[String]]
      // .. etc etc
      case _ => throw new RuntimeException(s"Could not create encoder for ${name} column (${fieldType})")
    }
    val tpe = tt.in(mirror).tpe

    val cls = mirror.runtimeClass(tpe)
    val serializer = ScalaReflection.serializerForType(tpe)
    val deserializer = ScalaReflection.deserializerForType(tpe)

    new ExpressionEncoder[Any](serializer, deserializer, ClassTag[Any](cls))
  }
}

The one thing I had to add is a result data type parameter to the aggregator. The usage then becomes:

df.agg(new CollectSetDemoAgg("rank", new ArrayType(IntegerType, true)).toColumn as "result").show()

I do not really like how it turned out, but it works. I would also welcome any suggestions on how to improve it.

answered 2020-09-16T14:04:46.343

A modification of @Ramunas's answer using generics:

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag}

import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.catalyst.ScalaReflection.{deserializerForType, mirror, serializerForType}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

class CollectSetDemoAgg[T: TypeTag](name: String) extends Aggregator[Row, Set[T], Seq[T]] {
  override def zero: Set[T] = Set.empty
  override def reduce(b: Set[T], a: Row): Set[T] = b + a.getAs[T](a.fieldIndex(name))
  override def merge(b1: Set[T], b2: Set[T]): Set[T] = b1 ++ b2
  override def finish(reduction: Set[T]): Seq[T] = reduction.toSeq
  override def bufferEncoder: Encoder[Set[T]] = Encoders.kryo[Set[T]]

  // Derive the output encoder for Seq[T] from the runtime TypeTag
  override def outputEncoder: Encoder[Seq[T]] = {
    val tt = typeTag[Seq[T]]
    val tpe = tt.in(mirror).tpe

    val cls = mirror.runtimeClass(tpe)
    val serializer = serializerForType(tpe)
    val deserializer = deserializerForType(tpe)

    new ExpressionEncoder[Seq[T]](serializer, deserializer, ClassTag[Seq[T]](cls))
  }
}
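
For reference, usage of the generic version would then look something like this (my addition, not part of the original answer; the "name" string column is hypothetical):

df.agg(new CollectSetDemoAgg[Int]("rank").toColumn as "result").show()
df.agg(new CollectSetDemoAgg[String]("name").toColumn as "names").show()
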
answered 2020-11-03T20:56:50.657