我已经在 Spark 2.4 上用 Scala 编写了一个 UDAF。由于我们的 Databricks 集群运行在已停止支持的 6.4 运行时上,我们需要迁移到具有长期支持且基于 Spark 3 的 7.3 LTS。`UserDefinedAggregateFunction` 在 Spark 3 中已被弃用,将来很可能会被移除,所以我正在尝试把这个 UDAF 改写为 Aggregator 函数。
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{IntegerType,StringType, StructField, StructType, DataType}
/**
 * Legacy Spark 2.x UDAF that keeps, per group, the (id, name) pair with the
 * largest non-null id.
 *
 * Null handling: a null buffered id is overwritten by the first incoming pair
 * (even if that pair's id is also null); after that, only rows with a
 * non-null id strictly greater than the buffered id win.
 *
 * NOTE(review): UserDefinedAggregateFunction is deprecated in Spark 3.x — the
 * replacement is org.apache.spark.sql.expressions.Aggregator registered via
 * functions.udaf(agg, encoder).
 */
object MaxCampaignIdAggregator extends UserDefinedAggregateFunction with java.io.Serializable {
  // Input per row: nullable (id, name) pair.
  override def inputSchema: StructType = new StructType()
    .add("id", IntegerType, true)
    .add("name", StringType, true)

  // Intermediate aggregation buffer mirrors the input shape.
  override def bufferSchema: StructType = new StructType()
    .add("id", IntegerType, true)
    .add("name", StringType, true)

  // Final result: a struct holding the winning id and its name.
  override def dataType: DataType = new StructType()
    .add("id", IntegerType, true)
    .add("name", StringType, true)

  // Same input rows always produce the same result.
  override def deterministic: Boolean = true

  // Called once per group before any rows are consumed: start empty.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = null
    buffer(1) = null
  }

  // Fold one input row into the buffer.
  override def update(buffer: MutableAggregationBuffer, inputRow: Row): Unit = {
    val inputId = inputRow.get(0) // boxed read, so null is detectable
    val inputName = inputRow.getString(1)
    if (buffer.get(0) == null) {
      // Empty buffer: take the incoming pair as-is (its id may itself be null).
      buffer(0) = inputId
      buffer(1) = inputName
    } else if (inputId != null && inputRow.getInt(0) > buffer.getInt(0)) {
      buffer(0) = inputId
      buffer(1) = inputName
    }
  }

  // Merge a partial aggregate (buffer2) into buffer1 using the same rule as update.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val rightId = buffer2.get(0) // boxed read, so null is detectable
    val rightName = buffer2.getString(1)
    if (buffer1.get(0) == null) {
      buffer1(0) = rightId
      buffer1(1) = rightName
    } else if (rightId != null && buffer2.getInt(0) > buffer1.getInt(0)) {
      buffer1(0) = rightId
      buffer1(1) = rightName
    }
  }

  // Called after all entries for the group are exhausted: emit the final struct.
  override def evaluate(buffer: Row): Any = Row(buffer.get(0), buffer.getString(1))
}
使用后,输出为:
{"id": 1282, "name": "McCormick Christmas"}
{"id": 1305, "name": "McCormick Perfect Pinch"}
{"id": 1677, "name": "Viking Cruises Viking Cruises"}