考虑使用 inline 和高阶函数 aggregate(在 Spark 2.4+ 中可用),
对 Array 类型的各列按位置逐元素求和,
然后使用 groupBy/agg 将各位置的求和结果重新收集为数组:
// Sample input: one integer id plus three equally sized integer-array columns.
val df = Seq(
  (101, Seq(1, 2), Seq(3, 4), Seq(5, 6)),
  (202, Seq(7, 8), Seq(9, 10), Seq(11, 12))
).toDF("id", "arr1", "arr2", "arr3")

// Column references for every column whose name starts with "arr"
// (arr1, arr2, arr3 for the sample data above).
val arrCols = df.columns.collect { case name if name.startsWith("arr") => col(name) }
对于 Spark 3.0+
// Steps:
//   1. arrays_zip  - build an array of structs, one struct per element position.
//   2. inline      - explode that array into one row per position, with the struct
//                    fields becoming the columns arr1, arr2, arr3 again (now scalars).
//   3. aggregate   - sum the scalars of a row, starting from 0.
//   4. collect_list - gather the per-position sums back into one array per id.
// NOTE(review): collect_list does not formally guarantee element order after the
// groupBy shuffle - confirm this is acceptable for the intended data sizes.
df.
  select(col("id"), arrays_zip(arrCols: _*).as("arr_structs")).
  select(col("id"), expr("inline(arr_structs)")).
  select(col("id"), aggregate(array(arrCols: _*), lit(0), _ + _).as("pos_elem_sum")).
  groupBy("id").
  agg(collect_list(col("pos_elem_sum")).as("arr_elem_sum")).
  show()
// +---+------------+
// | id|arr_elem_sum|
// +---+------------+
// |101|     [9, 12]|
// |202|    [27, 30]|
// +---+------------+
对于 Spark 2.4+
// Same pipeline as a chain of selects; here the higher-order function `aggregate`
// is invoked in its SQL form through `expr`, operating on the intermediate
// column arr_pos_elems that holds the values of one element position.
// NOTE(review): collect_list does not formally guarantee element order after the
// groupBy shuffle - confirm this is acceptable for the intended data sizes.
df.
  select(col("id"), arrays_zip(arrCols: _*).as("arr_structs")).
  select(col("id"), expr("inline(arr_structs)")).
  select(col("id"), array(arrCols: _*).as("arr_pos_elems")).
  select(col("id"), expr("aggregate(arr_pos_elems, 0, (acc, x) -> acc + x)").as("pos_elem_sum")).
  groupBy("id").
  agg(collect_list(col("pos_elem_sum")).as("arr_elem_sum")).
  show()
对于 Spark 2.3 或更低版本
// UDF that sums the integers belonging to one element position.
val sumArrElems = udf{ (arr: Seq[Int]) => arr.sum }

// `arrays_zip` only exists from Spark 2.4 onwards, so on 2.3 or lower the elements
// are regrouped by position with a UDF instead: transposing
// [[1, 2], [3, 4], [5, 6]] yields [[1, 3, 5], [2, 4, 6]].
// Assumes all arrays of a row are non-null and of equal length (transpose throws
// on ragged input) - verify against the real data.
val zipArrElems = udf{ (arrs: Seq[Seq[Int]]) => arrs.transpose }

// NOTE(review): collect_list does not formally guarantee element order after the
// groupBy shuffle - confirm this is acceptable for the intended data sizes.
df.
  withColumn("arr_pos_elems", zipArrElems(array(arrCols: _*))).
  select($"id", expr("explode(arr_pos_elems)").as("arr_pos_elems")).
  select($"id", sumArrElems($"arr_pos_elems").as("pos_elem_sum")).
  groupBy("id").agg(collect_list($"pos_elem_sum").as("arr_elem_sum")).
  show