34

我有一个使用 Python 的 scikit-learn 训练的分类器。如何使用 Java 程序中的分类器?我可以使用 Jython 吗?有没有办法在 Python 中保存分类器并在 Java 中加载它?还有其他使用方法吗?

4

6 回答 6

52

您不能使用 jython,因为 scikit-learn 严重依赖具有许多已编译 C 和 Fortran 扩展的 numpy 和 scipy,因此无法在 jython 中工作。

在 java 环境中使用 scikit-learn 的最简单方法是:

  • 将分类器公开为 HTTP / Json 服务,例如使用诸如烧瓶瓶子檐口之类的微框架,并使用 HTTP 客户端库从 java 调用它

  • 在 python 中编写一个命令行包装应用程序,它使用诸如 CSV 或 JSON(或一些较低级别的二进制表示)之类的格式读取 stdin 上的数据并在 stdout 上输出预测,并从 java 调用 python 程序,例如使用Apache Commons Exec

  • 使python程序输出在拟合时学习的原始数值参数(通常作为浮点值数组)并在java中重新实现预测函数(这对于预测通常只是一个阈值点积的线性预测模型来说通常很容易) .

如果您还需要在 Java 中重新实现特征提取,则最后一种方法将做更多的工作。

最后,您可以使用实现所需算法的 Java 库(例如 Weka 或 Mahout),而不是尝试使用 Java 中的 scikit-learn。

于 2012-10-05T09:05:29.017 回答
23

为此目的有JPMML项目。

首先,您可以直接从 python 使用sklearn2pmml库将 scikit-learn 模型序列化为 PMML(内部是 XML),或者先将其转储到 python 中,然后使用java 中的 jpmml-sklearn或从该库提供的命令行进行转换。接下来,您可以在 Java 代码中使用jpmml-evaluator加载 pmml 文件、反序列化和执行加载的模型。

这种方式并非适用于所有 scikit-learn 模型,但适用于其中的许多模型。

于 2016-08-10T16:31:57.537 回答
6

您可以使用搬运工,我已经测试了 sklearn-porter ( https://github.com/nok/sklearn-porter ),它适用于 Java。

我的代码如下:

import pandas as pd
from sklearn import tree
from sklearn_porter import Porter

train_dataset = pd.read_csv('./result2.csv').as_matrix()

X_train = train_dataset[:90, :8]
Y_train = train_dataset[:90, 8:]

X_test = train_dataset[90:, :8]
Y_test = train_dataset[90:, 8:]

print X_train.shape
print Y_train.shape


clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, Y_train)

porter = Porter(clf, language='java')
output = porter.export(embed_data=True)
print(output)

就我而言,我使用的是 DecisionTreeClassifier,输出

打印(输出)

是以下代码作为控制台中的文本:

class DecisionTreeClassifier {

  private static int findMax(int[] nums) {
    int index = 0;
    for (int i = 0; i < nums.length; i++) {
        index = nums[i] > nums[index] ? i : index;
    }
    return index;
  }


  public static int predict(double[] features) {
    int[] classes = new int[2];

    if (features[5] <= 51.5) {
        if (features[6] <= 21.0) {

            // HUGE amount of ifs..........

        }
    }

    return findMax(classes);
  }

  public static void main(String[] args) {
    if (args.length == 8) {

        // Features:
        double[] features = new double[args.length];
        for (int i = 0, l = args.length; i < l; i++) {
            features[i] = Double.parseDouble(args[i]);
        }

        // Prediction:
        int prediction = DecisionTreeClassifier.predict(features);
        System.out.println(prediction);

    }
  }
}
于 2018-03-04T16:26:02.913 回答
3

以下是 JPMML 解决方案的一些代码:

--Python部分--

# helper function to determine the string columns which have to be one-hot-encoded in order to apply an estimator.
def determine_categorical_columns(df):
    categorical_columns = []
    x = 0
    for col in df.dtypes:
        if col == 'object':
            val = df[df.columns[x]].iloc[0]
            if not isinstance(val,Decimal):
                categorical_columns.append(df.columns[x])
        x += 1
    return categorical_columns

categorical_columns = determine_categorical_columns(df)
other_columns = list(set(df.columns).difference(categorical_columns))


#construction of transformators for our example
labelBinarizers = [(d, LabelBinarizer()) for d in categorical_columns]
nones = [(d, None) for d in other_columns]
transformators = labelBinarizers+nones

mapper = DataFrameMapper(transformators,df_out=True)
gbc = GradientBoostingClassifier()

#construction of the pipeline
lm = PMMLPipeline([
    ("mapper", mapper),
    ("estimator", gbc)
])

--JAVA部分--

//Initialisation.
String pmmlFile = "ScikitLearnNew.pmml";
PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(new FileInputStream(pmmlFile));
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
MiningModelEvaluator evaluator = (MiningModelEvaluator) modelEvaluatorFactory.newModelEvaluator(pmml);

//Determine which features are required as input
HashMap<String, Field>() inputFieldMap = new HashMap<String, Field>();
for (int i = 0; i < evaluator.getInputFields().size();i++) {
  InputField curInputField = evaluator.getInputFields().get(i);
  String fieldName = curInputField.getName().getValue();
  inputFieldMap.put(fieldName.toLowerCase(),curInputField.getField());
}


//prediction

HashMap<String,String> argsMap = new HashMap<String,String>();
//... fill argsMap with input

Map<FieldName, ?> res;
// here we keep only features that are required by the model
Map<FieldName,String> args = new HashMap<FieldName, String>();
Iterator<String> iter = argsMap.keySet().iterator();
while (iter.hasNext()) {
  String key = iter.next();
  Field f = inputFieldMap.get(key);
  if (f != null) {
    FieldName name =f.getName();
    String value = argsMap.get(key);
    args.put(name, value);
  }
}
//the model is applied to input, a probability distribution is obtained
res = evaluator.evaluate(args);
SegmentResult segmentResult = (SegmentResult) res;
Object targetValue = segmentResult.getTargetValue();
ProbabilityDistribution probabilityDistribution = (ProbabilityDistribution) targetValue;
于 2018-02-16T13:05:25.480 回答
1

我发现自己处于类似的境地。我会推荐创建一个分类器微服务。您可以有一个在 python 中运行的分类器微服务,然后通过一些 RESTFul API 公开对该服务的调用,从而产生 JSON/XML 数据交换格式。我认为这是一种更清洁的方法。

于 2018-05-11T12:50:16.663 回答
1

或者,您可以从经过训练的模型生成 Python 代码。这是一个可以帮助您解决问题的工具https://github.com/BayesWitnesses/m2cgen

于 2019-02-21T17:57:05.530 回答