0

我正在尝试在 spark 中创建自己的自定义催化剂功能。这是一个实现,我希望函数的返回数据类型是 MapType。Literal(0)当我为地图中的所有键返回值时,我可以使用该函数。但我无法理解如何通过键从 mapType 属性参考(用于存储结果的缓冲区)中获取值以更新地图值。

我想通过在我的更新表达式函数中执行类似 customCount.get(nullKey) 的操作来获取值

case class CustomCount(children: Seq[Expression]) extends DeclarativeAggregate {

  override def nullable: Boolean = false

  override def dataType: MapType = {
    MapType(
      keyType = StringType,
      valueType = LongType,
      valueContainsNull = false)
  }

  private val customCountKey = "customCount"
  private val nullKey = "nullCount"
  private val notNullKey = "notNullCount"
  private val totalKey = "totalCount"

  private val mapType: MapType = MapType(
    keyType = StringType,
    valueType = LongType)

  private lazy val compositeCount = AttributeReference(customCountKey, mapType
  )()

  override lazy val aggBufferAttributes: Seq[AttributeReference] = compositeCount :: Nil

  override lazy val initialValues = Seq(
    CreateMap(Seq(nullKey, Literal(0L), notNullKey, Literal(0L), totalKey, Literal(0L)))
  )

  override lazy val mergeExpressions = Seq(
    // How to do this ?? Will need to merge two maps
    CreateMap(
      Seq(
        nullKey, Literal(0L),
        notNullKey, Literal(0L),
        totalKey, Literal(0L) // Does not matter, will get calculated at evaluate expression
      )
    )
  )

  override lazy val evaluateExpression: AttributeReference = compositeCount


  override def defaultResult: Option[Literal] = Option(Literal(Map(nullKey -> 0L, notNullKey -> 0L, totalKey -> 0L)))


  override lazy val updateExpressions = {
    val nullableChildren = children.filter(_.nullable)
    if (nullableChildren.isEmpty) {
      Seq(CreateMap(
        Seq(
          // How to do this ?? 
          nullKey, Literal(0L), // Would want to do something like customCount.get(nullKey)
          notNullKey, Literal(0L),
          totalKey, Literal(0L) 
        )
      ))
    } else {
      // Is any child null
      Seq(
        If(nullableChildren.map(IsNull).reduce(Or),
          CreateMap(
            Seq(
              nullKey, Literal(0L),
              notNullKey, Literal(0L),
              totalKey, Literal(0L) 
            )
          ),
          CreateMap(
            Seq(
              nullKey, Literal(0L),
              notNullKey, Literal(0L),
              totalKey, Literal(0L) 
            )
          ))
      )
    }
  }


  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): CustomCount =
    copy(children = newChildren)
}

4

0 回答 0