I have a trained PyTorch model that I now want to export to Caffe2 using ONNX. That part seems fairly straightforward and well documented. However, I then want to load the model in a Java program so I can run predictions inside my application (a Flink streaming job). What is the best way to do this? I couldn't find any documentation on the site describing how to do it.
This is a bit tricky at the moment, but it can be done. You will need the JavaCPP presets for both libraries (Maven coordinates are sketched after the list):
- nGraph: https://github.com/bytedeco/javacpp-presets/tree/master/ngraph
- ONNX: https://github.com/bytedeco/javacpp-presets/tree/master/onnx
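Both presets ship -platform artifacts that bundle the native binaries for all supported platforms. A sketch of the Maven dependencies this assumes follows; the version strings are my assumption based on the 1.5-era presets, so verify them against the repositories above:

<!-- version strings are assumptions; check the preset repositories for current ones -->
<dependency>
    <groupId>org.bytedeco</groupId>
    <artifactId>onnx-platform</artifactId>
    <version>1.5.0-1.5</version>
</dependency>
<dependency>
    <groupId>org.bytedeco</groupId>
    <artifactId>ngraph-platform</artifactId>
    <version>0.18.1-1.5</version>
</dependency>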
I will use single_relu.onnx as the example model:
// Imports for the JavaCPP ONNX and nGraph presets (exact package names may vary by preset version):
import java.nio.file.Files;
import java.nio.file.Paths;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.ngraph.*;
import org.bytedeco.onnx.ModelProto;
import org.bytedeco.onnx.StringVector;
import static org.bytedeco.ngraph.global.ngraph.*;
import static org.bytedeco.onnx.global.onnx.*;

// Read the ONNX file and parse it into a protobuf ModelProto
byte[] bytes = Files.readAllBytes(Paths.get("single_relu.onnx"));
ModelProto model = new ModelProto();
ParseProtoFromBytes(model, new BytePointer(bytes), bytes.length);
// Optionally preprocess the model: validate it, infer shapes, run optimizer
// passes, and convert it to a target opset version (you can skip this step).
// Note that Optimize and ConvertVersion return a new model rather than
// modifying the input in place, so the result must be assigned back.
check_model(model);
InferShapes(model);
StringVector passes = new StringVector("eliminate_nop_transpose", "eliminate_nop_pad",
        "fuse_consecutive_transposes", "fuse_transpose_into_gemm");
model = Optimize(model, passes);
check_model(model);
model = ConvertVersion(model, 8); // target ONNX opset 8

// Serialize the (possibly modified) model back to protobuf bytes
BytePointer serialized = model.SerializeAsString();
System.out.println("model=" + serialized.getString()); // debug: dumps raw protobuf bytes
// Prepare the nGraph CPU backend and input/output tensors of shape {1, 2}
Backend backend = Backend.create("CPU");
Shape shape = new Shape(new SizeTVector(1, 2));
Tensor input = backend.create_tensor(f32(), shape);
Tensor output = backend.create_tensor(f32(), shape);
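// The answer as posted never writes any data into the input tensor.
// A minimal sketch of one way to fill it, assuming the preset maps
// ngraph::runtime::Tensor::write with the same (pointer, offset,
// byte count) signature as the read() call used below:
float[] in = {-1.0f, 1.0f}; // sample values; a single ReLU should yield {0.0, 1.0}
FloatPointer inPtr = new FloatPointer(in);
input.write(inPtr, 0, in.length * 4);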
// Convert the ONNX model to an nGraph function, compile it, and run it;
// note that Executable.call takes the outputs first, then the inputs
Function ng_function = import_onnx_model(serialized);
Executable exec = backend.compile(ng_function);
exec.call(new NgraphTensorVector(output), new NgraphTensorVector(input));
// Copy the result back into a Java array (4 bytes per float)
float[] r = new float[2];
FloatPointer p = new FloatPointer(r);
output.read(p, 0, r.length * 4);
p.get(r);
// Print the result as a 2-D matrix
System.out.println("[");
for (int i = 0; i < shape.get(0); i++) {
    System.out.print(" [");
    for (int j = 0; j < shape.get(1); j++) {
        System.out.print(r[i * (int) shape.get(1) + j] + " ");
    }
    System.out.println("]");
}
System.out.println("]");
answered 2019-06-21 at 12:28