我想用网格搜索和火花交叉验证来调整我的模型。在spark中,必须将base model放到一个pipeline中,pipeline的office demo使用的LogistictRegression
是base model,可以new作为object。但是,该RandomForest
模型不能通过客户端代码新建,因此似乎无法RandomForest
在管道 api 中使用。我不想重新创建一个轮子,所以有人可以提供一些建议吗?谢谢
问问题
4241 次
1 回答
5
但是,RandomForest 模型不能通过客户端代码新建,因此在管道 api 中似乎无法使用 RandomForest。
嗯,这是真的,但你只是试图使用错误的类。而不是mllib.tree.RandomForest
你应该使用ml.classification.RandomForestClassifier
. 这是一个基于MLlib docs的示例。
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLUtils
import sqlContext.implicits._
case class Record(category: String, features: Vector)
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainData, testData) = (splits(0), splits(1))
val trainDF = trainData.map(lp => Record(lp.label.toString, lp.features)).toDF
val testDF = testData.map(lp => Record(lp.label.toString, lp.features)).toDF
val indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("label")
val rf = new RandomForestClassifier()
.setNumTrees(3)
.setFeatureSubsetStrategy("auto")
.setImpurity("gini")
.setMaxDepth(4)
.setMaxBins(32)
val pipeline = new Pipeline()
.setStages(Array(indexer, rf))
val model = pipeline.fit(trainDF)
model.transform(testDF)
有一件事我在这里想不通。据我所知,应该可以使用LabeledPoints
直接提取的标签,但由于某种原因它不起作用并pipeline.fit
引发IllegalArgumentExcetion
:
RandomForestClassifier 的输入带有无效的标签列标签,但未指定类的数量。
因此,丑陋的把戏与StringIndexer
. 应用后,我们获得了必需的属性 ( {"vals":["1.0","0.0"],"type":"nominal","name":"label"}
),但其中的某些类ml
似乎在没有它的情况下也能正常工作。
于 2015-08-20T05:21:05.070 回答