5

我正在重构我的代码以利用DataFrames、Estimators 和 Pipelines最初在RDD[LabeledPoint]. 我喜欢学习和使用新的 API,但我不确定如何保存我的新模型并将其应用于新数据。

目前,ML 的实现LogisticRegression只支持二分类。我是这样使用OneVsRest的:

val lr = new LogisticRegression().setFitIntercept(true)
val ovr = new OneVsRest()
ovr.setClassifier(lr)
val ovrModel = ovr.fit(training)

我现在想保存我的OneVsRestModel,但这似乎不受 API 支持。我试过了:

ovrModel.save("my-ovr") // Cannot resolve symbol save
ovrModel.models.foreach(_.save("model-" + _.uid)) // Cannot resolve symbol save

有没有办法保存这个,所以我可以将它加载到一个新的应用程序中进行新的预测?

4

1 回答 1

5

火花 2.0.0

OneVsRestModel实现MLWritable,所以应该可以直接保存它。下面显示的方法对于单独保存单个模型仍然有用。

火花 < 2.0.0

这里的问题是models返回一个Arrayof ClassificationModel[_, _]]not an Arrayof LogisticRegressionModel(or MLWritable)。要使其工作,您必须具体说明类型:

import org.apache.spark.ml.classification.LogisticRegressionModel

ovrModel.models.zipWithIndex.foreach { 
  case (model: LogisticRegressionModel, i: Int) => 
    model.save(s"model-${model.uid}-$i")
}

或者更通用:

import org.apache.spark.ml.util.MLWritable

ovrModel.models.zipWithIndex.foreach { 
  case (model: MLWritable, i: Int) =>
    model.save(s"model-${model.uid}-$i")
}

不幸的是,目前(Spark 1.6)OneVsRestModel没有实现MLWritable,所以它不能单独保存。

注意

所有模型OneVsRest似乎都使用相同的,uid因此我们需要一个显式索引。稍后识别模型也很有用。

于 2016-03-27T03:40:08.020 回答