0

我使用 Spark 2.4 在 Scala 中编写了一个 UDAF。由于我们的 Databricks 集群使用的 6.4 运行时已不再受支持,我们需要迁移到提供长期支持、基于 Spark 3 的 7.3 LTS 运行时。UDAF(UserDefinedAggregateFunction)在 Spark 3 中已被弃用,将来很可能会被移除,所以我正在尝试把这个 UDAF 改写为 Aggregator。

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{IntegerType,StringType, StructField, StructType, DataType}

/**
 * Aggregates (id, name) pairs, keeping the name that belongs to the largest non-null id.
 *
 * Semantics (identical for per-row `update` and partial-buffer `merge`):
 *  - While the buffer holds no id, the incoming pair is copied in unconditionally, even when
 *    its own id is null.
 *  - Once the buffer holds an id, it is replaced only by a candidate whose id is non-null and
 *    strictly greater. On ties the value already in the buffer is kept, so which tied row wins
 *    depends on the order in which Spark feeds rows and partial buffers.
 *
 * NOTE(review): `UserDefinedAggregateFunction` is deprecated as of Spark 3.0. The suggested
 * replacement is an `org.apache.spark.sql.expressions.Aggregator` registered through
 * `org.apache.spark.sql.functions.udaf`. This object keeps the UDAF API so existing callers
 * continue to work.
 */
object MaxCampaignIdAggregator extends UserDefinedAggregateFunction with java.io.Serializable {

  /** Input columns: a nullable integer id and a nullable string name. */
  override def inputSchema: StructType = new StructType()
    .add("id", IntegerType, true)
    .add("name", StringType, true)

  /** The buffer mirrors the input: the best (id, name) seen so far. */
  override def bufferSchema: StructType = new StructType()
    .add("id", IntegerType, true)
    .add("name", StringType, true)

  /** The result is a struct with the same (id, name) shape. */
  override def dataType: DataType = new StructType()
    .add("id", IntegerType, true)
    .add("name", StringType, true)

  /** The same input always yields the same output (ties aside; see class doc). */
  override def deterministic: Boolean = true

  /** Starts every group with an empty (null, null) buffer. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = null
    buffer(1) = null
  }

  /** Folds one input row of a group into the buffer. */
  override def update(buffer: MutableAggregationBuffer, inputRow: Row): Unit =
    keepLargerId(buffer, inputRow)

  /** Combines two partial aggregates; `buffer2` has the buffer schema. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    keepLargerId(buffer1, buffer2)

  /** Emits the final (id, name) struct for the group. */
  override def evaluate(buffer: Row): Any =
    Row(buffer.get(0), buffer.getString(1))

  /**
   * Shared step for `update` and `merge`: overwrite `buffer` with `candidate` when the buffer
   * has no id yet, or when the candidate's id is non-null and strictly larger.
   */
  private def keepLargerId(buffer: MutableAggregationBuffer, candidate: Row): Unit = {
    if (buffer.isNullAt(0)) {
      // Empty buffer: take the candidate as-is; its id may itself be null.
      buffer(0) = candidate.get(0)
      buffer(1) = candidate.getString(1)
    } else if (!candidate.isNullAt(0) && candidate.getInt(0) > buffer.getInt(0)) {
      // Read the ints only after the null checks; getAs[Int] on null would silently yield 0.
      buffer(0) = candidate.getInt(0)
      buffer(1) = candidate.getString(1)
    }
  }
}

使用该 UDAF 后,输出如下:

{"id": 1282, "name": "麦考密克圣诞节"}

{"id": 1305, "name": "麦考密克完美捏"}

{"id": 1677, "name": "Viking Cruises Viking Cruises"}

4

0 回答 0