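Problem statement:

I have a dataframe with four columns: service (String), show (String), country_1 (Integer), and country_2 (Integer). My goal is to produce a dataframe with only two columns: service (String) and information (Map[Integer, List[String]]),

where the map can contain multiple key-value records, such as the following for each streaming service: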

{
    "34521": ["The Crown", "Bridgerton", "The Queen's Gambit"],
    "49678": ["The Crown", "Bridgerton", "The Queen's Gambit"]
}

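An important point to note is that more countries may be added in the future, as additional columns in the input dataframe such as "country_3", "country_4", and so on. Ideally the solution code should handle these as well, rather than hard-coding the selected columns as I did in my attempted solution below, if that makes sense.

Input dataframe:

Schema: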

root
|-- service: string (nullable = true)
|-- show: string (nullable = true)
|-- country_1: integer (nullable = true)
|-- country_2: integer (nullable = true)

Dataframe:

service     |      show        |   country_1   |   country_2

Netflix      The Crown               34521           49678
Netflix      Bridgerton              34521           49678
Netflix      The Queen's Gambit      34521           49678
Peacock      The Office              34521           49678
Disney+      WandaVision             34521           49678 
Disney+      Marvel's 616            34521           49678
Disney+      The Mandalorian         34521           49678
Apple TV     Ted Lasso               34521           49678
Apple TV     The Morning Show        34521           49678

Output dataframe:

Schema:

root
|-- service: string (nullable = true)
|-- information: map (nullable = false)
|    |-- key: integer
|    |-- value: array (valueContainsNull = true)
|    |    |-- element: string (containsNull = true)

Dataframe:

service    |  information          

Netflix    [34521 -> [The Crown, Bridgerton, The Queen's Gambit], 49678 -> [The Crown, Bridgerton, The Queen's Gambit]]
Peacock    [34521 -> [The Office], 49678 -> [The Office]]
Disney+    [34521 -> [WandaVision, Marvel's 616, The Mandalorian], 49678 -> [WandaVision, Marvel's 616, The Mandalorian]]
Apple TV   [34521 -> [Ted Lasso, The Morning Show], 49678 -> [Ted Lasso, The Morning Show]]

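What I have tried

Although I have managed to produce my desired output with the code snippet pasted below, I don't want to rely on very basic SQL-type commands, because I don't think they are always optimal for fast computation on large datasets. In addition, I don't want to rely on manually selecting the country columns by their exact names when building the map, since that will keep changing as more country columns are added later.

Is there a better way, using udfs, foldLeft, or anything else, that helps with optimization and also makes the code more concise rather than messy?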

import org.apache.spark.sql.functions.{collect_list, map}

val df = spark.read.parquet("filepath/*.parquet")
// Collect the shows of each service; the country columns are kept as grouping keys
val grouped = df.groupBy("service", "country_1", "country_2").agg(collect_list("show").alias("show"))
val service_information = grouped.withColumn("information", map($"country_1", $"show", $"country_2", $"show")).drop("country_1", "country_2", "show")

1 Answer

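Given the country data "specification" described in the comments section (i.e. the country code will be identical and non-null in all rows for any given country_X column), your code can be generalized to handle arbitrarily many country columns: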

val df = Seq(
  ("Netflix",     "The Crown",             34521,    49678),
  ("Netflix",     "Bridgerton",            34521,    49678),
  ("Netflix",     "The Queen's Gambit",    34521,    49678),
  ("Peacock",     "The Office",            34521,    49678),
  ("Disney+",     "WandaVision",           34521,    49678),
  ("Disney+",     "Marvel's 616",          34521,    49678),
  ("Disney+",     "The Mandalorian",       34521,    49678),
  ("Apple TV",    "Ted Lasso",             34521,    49678),
  ("Apple TV",    "The Morning Show",      34521,    49678)
).toDF("service", "show", "country_1", "country_2")

val countryCols = df.columns.filter(_.startsWith("country_")).toList

val grouped = df.groupBy("service", countryCols: _*).agg(collect_list("show").as("shows"))

val service_information = grouped.withColumn(
    "information",
    map( countryCols.flatMap{ c => col(c) :: col("shows") :: Nil }: _* )
  ).drop("shows" :: countryCols: _*)

service_information.show(false)
// +--------+--------------------------------------------------------------------------------------------------------------+
// |service |information                                                                                                   |
// +--------+--------------------------------------------------------------------------------------------------------------+
// |Disney+ |[34521 -> [WandaVision, Marvel's 616, The Mandalorian], 49678 -> [WandaVision, Marvel's 616, The Mandalorian]]|
// |Peacock |[34521 -> [The Office], 49678 -> [The Office]]                                                                |
// |Netflix |[34521 -> [The Crown, Bridgerton, The Queen's Gambit], 49678 -> [The Crown, Bridgerton, The Queen's Gambit]]  |
// |Apple TV|[34521 -> [Ted Lasso, The Morning Show], 49678 -> [Ted Lasso, The Morning Show]]                              |
// +--------+--------------------------------------------------------------------------------------------------------------+
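
The generalized code above relies on the "specification" holding for the actual data. As a minimal sketch (not part of the original solution), one way to check it before grouping is to count the distinct values and the nulls in every country_X column. The sample rows, the alias suffixes "_distinct" and "_nulls", and the name specHolds are assumptions made for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").appName("check-country-spec").getOrCreate()
import spark.implicits._

// Small sample in the same shape as the question's input
val df = Seq(
  ("Netflix", "The Crown",  34521, 49678),
  ("Netflix", "Bridgerton", 34521, 49678),
  ("Peacock", "The Office", 34521, 49678)
).toDF("service", "show", "country_1", "country_2")

val countryCols = df.columns.filter(_.startsWith("country_")).toList

// For every country column: number of distinct non-null values and number of nulls
val checks = countryCols.flatMap { c =>
  countDistinct(col(c)).as(s"${c}_distinct") ::
    sum(when(col(c).isNull, 1).otherwise(0)).as(s"${c}_nulls") :: Nil
}
val stats = df.agg(checks.head, checks.tail: _*).head()

// The specification holds when each column has exactly one distinct value and no nulls
val specHolds = countryCols.forall { c =>
  stats.getAs[Long](s"${c}_distinct") == 1L && stats.getAs[Long](s"${c}_nulls") == 0L
}
println(s"specification holds: $specHolds")
```

If specHolds is false, grouping by the country columns would split a service into several rows, so the generalized code would no longer produce one map per service.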

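Note that the described country "specification" would require all shows to be associated with the same list of countries. For example, if you have 3 country_X columns and every row of a given country_X is identical with no nulls, that means every show is associated with those 3 countries. What if you have a show that is available in only 2 of the 3 countries?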
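
To make that case concrete, here is a hedged sketch that assumes (hypothetically, this is not stated in the question) that a null in a country_X column means "not available in that country". Under that assumption the wide columns could be packed into a single array per show with the nulls dropped. The sample rows and the names wide and withCountries are made up for illustration; the SQL higher-order function filter is used through expr.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").appName("wide-to-array").getOrCreate()
import spark.implicits._

// Hypothetical wide input: None marks a country in which the show is unavailable
val wide = Seq[(String, String, Option[Int], Option[Int])](
  ("Netflix", "The Crown",  Some(34521), Some(49678)),
  ("Netflix", "Bridgerton", Some(34521), None)
).toDF("service", "show", "country_1", "country_2")

val countryCols = wide.columns.filter(_.startsWith("country_")).toList

// Pack all country columns into one array and keep only the non-null entries
val withCountries = wide.select(
  $"service",
  $"show",
  expr(s"filter(array(${countryCols.mkString(", ")}), x -> x IS NOT NULL)").as("countries")
)

withCountries.show(false)
```

Each show then carries only the countries it is actually available in, regardless of how many country_X columns exist.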


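If your data schema can be modified, a more flexible way to maintain the associated country information would be to keep a single array column of countries for each show.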

val df = Seq(
  ("Netflix",     "The Crown",             Seq(34521, 49678)),
  ("Netflix",     "Bridgerton",            Seq(34521)),
  ("Netflix",     "The Queen's Gambit",    Seq(10001, 49678)),
  ("Peacock",     "The Office",            Seq(34521, 49678)),
  ("Disney+",     "WandaVision",           Seq(10001, 20002, 34521)),
  ("Disney+",     "Marvel's 616",          Seq(49678)),
  ("Disney+",     "The Mandalorian",       Seq(34521, 49678)),
  ("Apple TV",    "Ted Lasso",             Seq(34521, 49678)),
  ("Apple TV",    "The Morning Show",      Seq(20002, 34521))
).toDF("service", "show", "countries")

val grouped = df.withColumn("country", explode($"countries")).
  groupBy("service", "country").agg(collect_list($"show").as("shows"))

val service_information = grouped.groupBy("service").
  agg(collect_list($"country").as("c_list"), collect_list($"shows").as("s_list")).
  select($"service", map_from_arrays($"c_list", $"s_list").as("information"))

service_information.show(false)
// +--------+-----------------------------------------------------------------------------------------------------------------------------------+
// |service |information                                                                                                                        |
// +--------+-----------------------------------------------------------------------------------------------------------------------------------+
// |Peacock |[34521 -> [The Office], 49678 -> [The Office]]                                                                                     |
// |Disney+ |[20002 -> [WandaVision], 49678 -> [Marvel's 616, The Mandalorian], 34521 -> [WandaVision, The Mandalorian], 10001 -> [WandaVision]]|
// |Apple TV|[34521 -> [Ted Lasso, The Morning Show], 49678 -> [Ted Lasso], 20002 -> [The Morning Show]]                                        |
// |Netflix |[49678 -> [The Crown, The Queen's Gambit], 10001 -> [The Queen's Gambit], 34521 -> [The Crown, Bridgerton]]                        |
// +--------+-----------------------------------------------------------------------------------------------------------------------------------+
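
The code above builds the map from two separately collected lists, which depends on both lists coming out in the same order. A variant sketch, offered as an alternative rather than as the original method, collects (country, shows) pairs as structs and builds the map with map_from_entries, so each key stays attached to its value. The reduced sample data and the names perCountry and serviceInformation are assumptions for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").appName("map-from-entries").getOrCreate()
import spark.implicits._

val df = Seq(
  ("Netflix", "The Crown",  Seq(34521, 49678)),
  ("Netflix", "Bridgerton", Seq(34521)),
  ("Peacock", "The Office", Seq(34521, 49678))
).toDF("service", "show", "countries")

// One row per (service, country) with the list of shows available there
val perCountry = df.withColumn("country", explode($"countries")).
  groupBy("service", "country").agg(collect_list($"show").as("shows"))

// Collect (country, shows) structs and turn them into a map in a single aggregation
val serviceInformation = perCountry.groupBy("service").
  agg(map_from_entries(collect_list(struct($"country", $"shows"))).as("information"))

serviceInformation.show(false)
```

The order of the shows inside each list is still whatever collect_list produces, so it should not be relied upon.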
answered 2021-02-10T04:39:07.203