
I am trying to get the sum of a column over the top n rows of a persisted DataFrame. For some reason, the following does not work:

val df = df0.sort(col("colB").desc).persist()
df.limit(2).agg(sum("colB")).show()

It shows a random number that is clearly smaller than the sum of the top two values, and the number changes from run to run. Calling show() on the limit()'ed DF does consistently display the correct top two values:

df.limit(2).show()

It is as if sort() were not applied before the aggregation. Is this a bug in Spark? I suppose losing the ordering on persist() is somehow expected, but then why does it work with show(), and shouldn't this be documented somewhere?


1 Answer


See the query plans below. The agg causes an exchange (line 4 in the physical plan) that drops the ordering: LocalLimit takes two rows from each cached partition, and after the exchange GlobalLimit keeps an arbitrary two of them, which is why you see a random, smaller sum. show, by contrast, causes no exchange, so the ordering is maintained.

scala> df.limit(2).agg(sum("colB")).explain()
== Physical Plan ==
*(2) HashAggregate(keys=[], functions=[sum(cast(colB#4 as bigint))])
+- *(2) HashAggregate(keys=[], functions=[partial_sum(cast(colB#4 as bigint))])
   +- *(2) GlobalLimit 2
      +- Exchange SinglePartition, true, [id=#95]
         +- *(1) LocalLimit 2
            +- *(1) ColumnarToRow
               +- InMemoryTableScan [colB#4]
                     +- InMemoryRelation [colB#4], StorageLevel(disk, memory, deserialized, 1 replicas)
                           +- *(1) Sort [colB#4 DESC NULLS LAST], true, 0
                              +- Exchange rangepartitioning(colB#4 DESC NULLS LAST, 200), true, [id=#7]
                                 +- LocalTableScan [colB#4]


scala> df.limit(2).explain()
== Physical Plan ==
CollectLimit 2
+- *(1) ColumnarToRow
   +- InMemoryTableScan [colB#4]
         +- InMemoryRelation [colB#4], StorageLevel(disk, memory, deserialized, 1 replicas)
               +- *(1) Sort [colB#4 DESC NULLS LAST], true, 0
                  +- Exchange rangepartitioning(colB#4 DESC NULLS LAST, 200), true, [id=#7]
                     +- LocalTableScan [colB#4]
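
The nondeterminism described in the question can be surfaced by rerunning the aggregation a few times in the same spark-shell session (a sketch; the printed sums may differ between runs):

// GlobalLimit keeps an arbitrary two of the rows emitted by the
// per-partition LocalLimit, so repeated runs may sum different rows.
(1 to 3).foreach(_ => df.limit(2).agg(sum("colB")).show())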

However, if you persist the limited DataFrame, no exchange is performed for the aggregation, so this might work:

val df1 = df.limit(2).persist()

scala> df1.agg(sum("colB")).explain()
== Physical Plan ==
*(1) HashAggregate(keys=[], functions=[sum(cast(colB#4 as bigint))])
+- *(1) HashAggregate(keys=[], functions=[partial_sum(cast(colB#4 as bigint))])
   +- *(1) ColumnarToRow
      +- InMemoryTableScan [colB#4]
            +- InMemoryRelation [colB#4], StorageLevel(disk, memory, deserialized, 1 replicas)
                  +- CollectLimit 2
                     +- *(1) ColumnarToRow
                        +- InMemoryTableScan [colB#4]
                              +- InMemoryRelation [colB#4], StorageLevel(disk, memory, deserialized, 1 replicas)
                                    +- *(1) Sort [colB#4 DESC NULLS LAST], true, 0
                                       +- Exchange rangepartitioning(colB#4 DESC NULLS LAST, 200), true, [id=#7]
                                          +- LocalTableScan [colB#4]
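
With df1 cached, the aggregation can then be run as usual (a quick check; the printed sum depends on the data in df0):

// The two cached rows are read straight from the in-memory relation,
// so no further shuffle of the full dataset is involved.
df1.agg(sum("colB")).show()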

In any case, it is better to use a window function to assign row numbers and then sum the rows whose row number satisfies the condition (e.g. row_number <= 2). That yields a deterministic result. For example,

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

df0.withColumn(
    "rn",
    row_number().over(Window.orderBy($"colB".desc))
).filter("rn <= 2").agg(sum("colB"))
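
As a self-contained end-to-end check (a sketch assuming a toy df0 built from a local sequence in spark-shell, where $ and toDF are available via the implicit imports):

val df0 = Seq(1, 5, 3, 7, 2).toDF("colB")

// The global row_number ordering always keeps the two largest
// values (7 and 5), so this should deterministically print 12.
df0.withColumn(
    "rn",
    row_number().over(Window.orderBy($"colB".desc))
).filter("rn <= 2").agg(sum("colB")).show()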