
I have a trained PyTorch model, and I now want to export it to Caffe2 using ONNX. That part seems fairly straightforward and well documented. However, I then want to "load" that model into a Java program in order to perform predictions within my program (a Flink streaming application). What is the best way to do this? I couldn't find any documentation on the website describing how to do it.


1 Answer


This is currently a bit tricky, but there is a way. You will need to use JavaCPP (specifically the JavaCPP Presets for ONNX and nGraph):
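To get the classes used below on the classpath, the usual approach is to pull in the `onnx-platform` and `ngraph-platform` artifacts from the JavaCPP Presets. A minimal Maven sketch (the version element is left as a placeholder; pick the javacpp-presets release that matches your setup):

```xml
<dependency>
  <groupId>org.bytedeco</groupId>
  <artifactId>onnx-platform</artifactId>
  <version><!-- your javacpp-presets release --></version>
</dependency>
<dependency>
  <groupId>org.bytedeco</groupId>
  <artifactId>ngraph-platform</artifactId>
  <version><!-- your javacpp-presets release --></version>
</dependency>
```

The `-platform` artifacts bundle the native binaries for all supported platforms, which is convenient for a Flink job that may run on heterogeneous task managers.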

I will use single_relu.onnx as an example:

    // Uses the JavaCPP Presets for ONNX and nGraph; assumes static imports
    // such as org.bytedeco.onnx.global.onnx.* and org.bytedeco.ngraph.global.ngraph.*

    // read the ONNX file and parse it into a protobuf ModelProto
    byte[] bytes = Files.readAllBytes(Paths.get("single_relu.onnx"));
    ModelProto model = new ModelProto();
    ParseProtoFromBytes(model, new BytePointer(bytes), bytes.length);

    // preprocess the model in any way you like (you can skip this step)
    check_model(model);
    InferShapes(model);
    StringVector passes = new StringVector("eliminate_nop_transpose", "eliminate_nop_pad", "fuse_consecutive_transposes", "fuse_transpose_into_gemm");
    Optimize(model, passes);
    check_model(model);
    ConvertVersion(model, 8);
    BytePointer serialized = model.SerializeAsString();
    System.out.println("model=" + serialized.getString());

    // prepare the nGraph backend and input/output tensors
    Backend backend = Backend.create("CPU");
    Shape shape = new Shape(new SizeTVector(1, 2));
    Tensor input = backend.create_tensor(f32(), shape);
    Tensor output = backend.create_tensor(f32(), shape);

    // fill the input tensor (the original snippet left it uninitialized)
    float[] in = {-1.0f, 2.0f};
    FloatPointer pin = new FloatPointer(in);
    input.write(pin, 0, in.length * 4);

    // convert ONNX -> nGraph, compile, and execute
    Function ng_function = import_onnx_model(serialized);
    Executable exec = backend.compile(ng_function);
    exec.call(new NgraphTensorVector(output), new NgraphTensorVector(input));

    // copy the result back into a Java array
    float[] r = new float[2];
    FloatPointer p = new FloatPointer(r);
    output.read(p, 0, r.length * 4);
    p.get(r);

    // print the result
    System.out.println("[");
    for (int i = 0; i < shape.get(0); i++) {
        System.out.print(" [");
        for (int j = 0; j < shape.get(1); j++) {
            System.out.print(r[i * (int)shape.get(1) + j] + " ");
        }
        System.out.println("]");
    }
    System.out.println("]");
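For the Flink side of the question (not covered above), the usual pattern is to compile the model once per task, e.g. in `RichMapFunction.open()`, and reuse the compiled `Executable` for every record. A minimal self-contained sketch of that structure, where `Predictor` and `loadModel` are hypothetical names and the dummy ReLU body stands in for the `exec.call(...)` pipeline shown above:

```java
import java.util.function.Function;

public class PredictorSketch {
    // a tiny predictor interface a Flink operator could hold as a field
    interface Predictor extends Function<float[], float[]> {}

    // expensive one-time setup: in a real job this would parse the ONNX
    // bytes and compile the nGraph Executable, exactly as in the answer
    static Predictor loadModel() {
        return x -> {
            float[] y = new float[x.length];
            for (int i = 0; i < x.length; i++) {
                y[i] = Math.max(0f, x[i]); // stand-in for the single-ReLU model
            }
            return y;
        };
    }

    public static void main(String[] args) {
        Predictor model = loadModel();          // do once, e.g. in open()
        float[] out = model.apply(new float[]{-1f, 2f}); // per record, in map()
        System.out.println(out[0] + " " + out[1]); // prints "0.0 2.0"
    }
}
```

Keeping the compiled model as an operator field avoids re-parsing the ONNX file per record; in Flink this would typically live in a `RichMapFunction`, with the file loaded in `open()`.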
answered 2019-06-21T12:28:34.920