10

我已经能够创建一个管道,允许我一次索引多个字符串列,但我无法对它们进行编码,因为与索引不同,编码器不是估计器,所以我从不根据OneHotEncoder 中的示例调用 fit文档

import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler, 

OneHotEncoder}
import org.apache.spark.ml.Pipeline

val data = sqlContext.read.parquet("s3n://map2-test/forecaster/intermediate_data")

val df = data.select("win","bid_price","domain","size", "form_factor").na.drop()


//indexing columns
val stringColumns = Array("domain","size", "form_factor")
val index_transformers: Array[org.apache.spark.ml.PipelineStage] = stringColumns.map(
  cname => new StringIndexer()
    .setInputCol(cname)
    .setOutputCol(s"${cname}_index")
)

// Add the rest of your pipeline like VectorAssembler and algorithm
val index_pipeline = new Pipeline().setStages(index_transformers)
val index_model = index_pipeline.fit(df)
val df_indexed = index_model.transform(df)


//encoding columns
val indexColumns  = df_indexed.columns.filter(x => x contains "index")
val one_hot_encoders: Array[org.apache.spark.ml.PipelineStage] = indexColumns.map(
    cname => new OneHotEncoder()
     .setInputCol(cname)
     .setOutputCol(s"${cname}_vec")
)



val one_hot_pipeline = new Pipeline().setStages(one_hot_encoders)
val df_encoded = one_hot_pipeline.transform(df_indexed)

OneHotEncoder 对象没有 fit 方法,因此将其与索引器放在同一管道中将不起作用 - 当我在管道上调用 fit 时会引发错误。我也不能在使用管道阶段数组创建的管道上调用转换,one_hot_encoders.

我还没有找到一个很好的解决方案来使用 OneHotEncoder 而不单独为我要编码的所有列创建和调用转换本身

4

1 回答 1

6

火花 >= 3.0

在 Spark 3.0OneHotEncoderEstimator中已重命名为OneHotEncoder

import org.apache.spark.ml.feature.{OneHotEncoder, OneHotEncoderModel}

val encoder = new OneHotEncoder()
  .setInputCols(indexColumns)
  .setOutputCols(indexColumns map (name => s"${name}_vec"))

火花 >= 2.3

Spark 2.3 引入了新的类OneHotEncoderEstimatorOneHotEncoderModel即使在外面使用也需要拟合Pipeline,并且同时对多个列进行操作。

import org.apache.spark.ml.feature.{OneHotEncoderEstimator, OneHotEncoderModel}

val encoder = new OneHotEncoderEstimator()
  .setInputCols(indexColumns)
  .setOutputCols(indexColumns map (name => s"${name}_vec"))


encoder.fit(df_indexed).transform(df_indexed)

火花 < 2.3

即使您使用的转换器不需要拟合,您也必须使用fit方法来创建PipelineModel可用于转换数据的方法。

one_hot_pipeline.fit(df_indexed).transform(df_indexed)

在旁注中,您可以将索引和编码组合成一个Pipeline

val pipeline = new Pipeline()
  .setStages(index_transformers ++ one_hot_encoders)

val model = pipeline.fit(df)
model.transform(df)

编辑

您看到的错误意味着您的一列包含一个空的String. 它被索引器接受,但不能用于编码。根据您的要求,您可以删除这些或使用虚拟标签。不幸的是,在SPARK-11569)得到解决NULLs之前,您无法使用。

于 2015-12-08T22:22:08.033 回答