2

我目前正在使用 StringIndexer 将许多列转换为唯一整数,以便在 RandomForestModel 中进行分类。我还在为 ML 流程使用管道。

一些查询是

  1. RandomForestModel 如何知道哪些列是分类的。StringIndexer 将非数字转换为数字,但它是否添加了一些元数据以表明它是一个分类列?在 mllib.tree.RF 中有参数调用 categoricalInfo ,它指示列是分类的。ml.tree.RF 如何知道哪些是不存在的。

  2. 此外,StringIndexer 根据出现频率将类别映射到整数。现在,当新数据出现时,我如何确保这些数据与训练数据的编码一致?如果不再次对整个数据(包括新数据)进行 StringIndexing,我可以这样做吗?

我对如何实现这一点感到很困惑。

4

1 回答 1

3

如果不再次对整个数据(包括新数据)进行 StringIndexing,是否可以这样做?

对的,这是可能的。您只需要使用适合训练数据的索引器。如果您使用 ML 管道,它将为您直接使用StringIndexerModel

import org.apache.spark.ml.feature.StringIndexer

val train = sc.parallelize(Seq((1, "a"), (2, "a"), (3, "b"))).toDF("x", "y")
val test  = sc.parallelize(Seq((1, "a"), (2, "b"), (3, "b"))).toDF("x", "y")

val indexer = new StringIndexer()
  .setInputCol("y")
  .setOutputCol("y_index")
  .fit(train)

indexer.transform(train).show

// +---+---+-------+
// |  x|  y|y_index|
// +---+---+-------+
// |  1|  a|    0.0|
// |  2|  a|    0.0|
// |  3|  b|    1.0|
// +---+---+-------+

indexer.transform(test).show

// +---+---+-------+
// |  x|  y|y_index|
// +---+---+-------+
// |  1|  a|    0.0|
// |  2|  b|    1.0|
// |  3|  b|    1.0|
// +---+---+-------+

一个可能的警告是它不能优雅地处理看不见的标签,因此您必须在转换之前删除这些标签。

RandomForestModel 如何知道哪些列是分类的。

不同的 ML 转换器向转换后的列添加特殊的特殊元数据,这些元数据指示列的类型、类的数量等。

import org.apache.spark.ml.attribute._
import org.apache.spark.ml.feature.VectorAssembler

val assembler = new VectorAssembler()
  .setInputCols(Array("x", "y_index"))
  .setOutputCol("features")

val transformed = assembler.transform(indexer.transform(train))
val meta = AttributeGroup.fromStructField(transformed.schema("features"))
meta.attributes.get

// Array[org.apache.spark.ml.attribute.Attribute] = Array(
//   {"type":"numeric","idx":0,"name":"x"},
//   {"vals":["a","b"],"type":"nominal","idx":1,"name":"y_index"})

或者

transformed.select($"features").schema.fields.last.metadata
// "ml_attr":{"attrs":{"numeric":[{"idx":0,"name":"x"}], 
//  "nominal":[{"vals":["a","b"],"idx":1,"name":"y_index"}]},"num_attrs":2}}
于 2015-12-03T16:53:12.563 回答