0

我有一个包含以下数据的 DataFrame:

// Sample input: one row with a populated array (id 1), one with an empty array
// (id 2) and one with a null array (id 3).
// NOTE(review): presumably run in spark-shell, where `spark.implicits._` (needed
// for `toDF`) is already in scope — confirm when running elsewhere.
val df = List(
  (1, "wwe", List(1, 2, 3)),
  (2, "dsad", List.empty),
  (3, "dfd", null)).toDF("id", "name", "value")

// `show` is side-effecting, so call it with explicit parentheses
// (auto-application of nullary methods is deprecated in Scala 2.13).
df.show()
+---+----+---------+
| id|name|    value|
+---+----+---------+
|  1| wwe|[1, 2, 3]|
|  2|dsad|       []|
|  3| dfd|     null|
+---+----+---------+

为了将数组列的值展开(explode)成多行,我使用了以下逻辑:

/** Builds an `explode` expression for the array column described by `f`.
  *
  * A null array is replaced by a one-element default array before exploding, so
  * that row is kept with a default value: `0` for IntegerType, `0.0` for DoubleType
  * and `""` for StringType and every other element type. A non-null but empty
  * array is passed through unchanged, and `explode` emits no row for it, which is
  * why the id = 2 row disappears from the result.
  *
  * NOTE(review): for element types other than String/Double/Integer the `""`
  * default may not match the column's element type; confirm Spark coerces it as intended.
  *
  * @param f schema field of the column to explode; its `dataType` must be an
  *          `ArrayType` (the cast throws `ClassCastException` otherwise)
  * @return the exploded column expression
  */
def explodeWithNull(f: StructField): Column = {
  val source = col(f.name)
  val nullDefault: Column = f.dataType.asInstanceOf[ArrayType].elementType match {
    case IntegerType => array(lit(0))
    case DoubleType  => array(lit(0.0))
    case _           => array(lit(""))
  }
  explode(when(source.isNotNull, source).otherwise(nullDefault))
}
/** Explodes every top-level array column of `dataframe`, one column after another.
  *
  * Each array column is replaced in place by its elements via `explodeWithNull`.
  * Because the columns are exploded sequentially, several array columns yield the
  * per-row cross product of their elements. Rows whose array is empty (not null)
  * produce no output rows.
  *
  * @param dataframe input DataFrame
  * @return the DataFrame with every array column exploded
  */
def explodeAllArraysColumns(dataframe: DataFrame): DataFrame = {
  val arrayFields: Seq[StructField] =
    dataframe.schema.filter(_.dataType.typeName == "array")
  arrayFields.foldLeft(dataframe) { (acc, field) =>
    acc.withColumn(field.name, explodeWithNull(field))
  }
}

// Print the exploded result; `show` is side-effecting, so call it with `()`.
explodeAllArraysColumns(df).show()
+---+----+-----+
| id|name|value|
+---+----+-----+
|  1| wwe|    1|
|  1| wwe|    2|
|  1| wwe|    3|
|  3| dfd|    0|
+---+----+-----+

以这种方式 explode 之后,df 中数组为空的那一行(id = 2)丢失了。理想情况下,我不想丢失这一行,而是希望它被保留下来,并且在 explode 后的 DataFrame 中该列取 null 或该列类型的默认值。如何实现?

4

2 个回答

0
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame
from pyspark.sql import Row
from pyspark.sql.types import ArrayType
from pyspark.sql.functions import *
from functools import reduce

    def explode_outer(df, columns_to_explode):
        """Explode each listed array column while keeping rows whose array is null or empty.

        Plain ``explode`` emits no output row for a null or empty array, so such
        rows would disappear. Each such array is first replaced by a one-element
        array holding a null of the column's element type, so the row survives
        with a null in the exploded column. Columns are exploded one after
        another, so several listed columns yield the per-row cross product.

        NOTE(review): Spark 2.3+ also provides
        ``pyspark.sql.functions.explode_outer``; this definition shadows the one
        pulled in by ``from pyspark.sql.functions import *``.

        :param df: input DataFrame.
        :param columns_to_explode: names of ArrayType columns of ``df`` to explode.
        :return: DataFrame in which each listed column holds single elements.
        :raises KeyError: if a listed name is not an ArrayType column of ``df``.
        """
        element_types = {field.name: field.dataType.elementType
                         for field in df.schema.fields
                         if isinstance(field.dataType, ArrayType)}

        def explode_one(df_with_explode, column):
            values = df_with_explode[column]
            placeholder = array(lit(None).cast(element_types[column]))
            # size() returns -1 for a null array (null under ANSI mode), so the
            # "> 0" test sends null and empty arrays alike to the placeholder.
            # A "!= 0" test would let null arrays through and explode would drop them.
            return df_with_explode.withColumn(
                column, explode(when(size(values) > 0, values).otherwise(placeholder)))

        return reduce(explode_one, columns_to_explode, df)
于 2018-07-09T08:52:12.607 回答
0
from pyspark.sql.functions import *

def flatten_df(nested_df):
    """Promote the fields of every top-level struct column to top-level columns.

    A struct column ``s`` with fields ``a`` and ``b`` becomes columns ``s_a`` and
    ``s_b``; non-struct columns are kept as they are and come first. Only one
    level is flattened: a struct field that is itself a struct stays a struct.
    Prints the row count of the result, which triggers a Spark job.

    :param nested_df: DataFrame to flatten.
    :return: the flattened DataFrame.
    """
    flat_cols, nested_cols = [], []
    for name, dtype in nested_df.dtypes:
        if dtype.startswith('struct'):
            nested_cols.append(name)
        else:
            flat_cols.append(name)

    promoted = [col(parent + '.' + child).alias(parent + '_' + child)
                for parent in nested_cols
                for child in nested_df.select(parent + '.*').columns]
    flat_df = nested_df.select(flat_cols + promoted)
    print("flatten_df_count :", flat_df.count())
    return flat_df

def explode_df(nested_df):
    """Explode all top-level array columns of ``nested_df`` together, row-aligned.

    The array columns are combined element-wise with ``arrays_zip`` and the
    zipped array is exploded, so element i of every array lands in the same
    output row. A null or empty array is first replaced by a one-element array
    holding a typed null, so no input row is lost. Every non-array column
    (struct columns included) is carried through and placed after the exploded
    columns. If there are no array columns, the DataFrame is returned unchanged.
    Prints the row count of the result, which triggers a Spark job.

    :param nested_df: DataFrame whose array columns should be exploded.
    :return: DataFrame with each array column replaced by its elements.
    """
    # Read column types from the argument itself, never from a global DataFrame.
    dtypes = dict(nested_df.dtypes)
    array_cols = [c for c, t in nested_df.dtypes if t[:5] == 'array']
    # Struct columns are kept as well; dropping them would silently lose data
    # that a later flatten_df pass can still promote.
    other_cols = [c for c, t in nested_df.dtypes if t[:5] != 'array']

    # With no array columns, arrays_zip() would yield empty arrays and explode
    # would drop every row, so only zip/explode when there is something to do.
    if array_cols:
        for array_col in array_cols:
            # size() is -1 for a null array (null under ANSI mode), so "> 0"
            # routes both null and empty arrays to the typed-null placeholder;
            # a row whose arrays are all empty would otherwise zip to [] and vanish.
            nested_df = nested_df.withColumn(
                array_col,
                when(size(col(array_col)) > 0, col(array_col))
                .otherwise(array(lit(None)).cast(dtypes[array_col])))
        nested_df = (nested_df
                     .withColumn("tmp", arrays_zip(*array_cols))
                     .withColumn("tmp", explode("tmp"))
                     .select([col("tmp." + c).alias(c) for c in array_cols] + other_cols))

    print("explode_dfs_count :", nested_df.count())
    return nested_df


# Alternately explode array columns and flatten struct columns until no
# top-level array column remains, then print the resulting schema.
# NOTE(review): `myDf` is the input DataFrame; it is not defined in this snippet.
# NOTE(review): the loop stops once no array column is left, so struct columns
# still present after the final flatten_df pass are not flattened further.
new_df = flatten_df(myDf)
array_cols = [name for name, dtype in new_df.dtypes if dtype.startswith('array')]
while array_cols:
    new_df = flatten_df(explode_df(new_df))
    array_cols = [name for name, dtype in new_df.dtypes if dtype.startswith('array')]

new_df.printSchema()

使用 `arrays_zip` 和 `explode` 解决了这个问题(`arrays_zip` 需要 Spark 2.4 及以上版本)。

于 2021-06-23T05:42:26.767 回答