
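I'm trying to build an ONNX graph using the helper API. The simplest example I started with is the following: a MatMul operation that takes two inputs of shape [1] (X and W) and produces an output Y of shape [1].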

import numpy as np
import onnxruntime as rt
from onnx import *
from onnxmltools.utils import save_model

initializer = []
initializer.append(helper.make_tensor(name="W", data_type=TensorProto.FLOAT, dims=(1,), vals=np.ones(1).tolist()))

graph = helper.make_graph(
    [
        helper.make_node('MatMul', ["X", "W"], ["Y"]),
    ],
    "TEST",
    [
        helper.make_tensor_value_info('X' , TensorProto.FLOAT, [1]),
        helper.make_tensor_value_info('W', TensorProto.FLOAT, [1]),
    ],
    [
        helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1]),
    ],
    initializer=initializer,
    )

checker.check_graph(graph)
model = helper.make_model(graph, producer_name='TEST')
save_model(model, "model.onnx")
sess = rt.InferenceSession('model.onnx')

When I run it, it fails with this error:

Traceback (most recent call last):
  File "onnxruntime_test.py", line 35, in <module>
    sess = rt.InferenceSession('model.onnx')
  File "/usr/local/lib/python3.5/dist-packages/onnxruntime/capi/session.py", line 29, in __init__
    self._sess.load_model(path_or_bytes)
RuntimeError: [ONNXRuntimeError] : 1 : GENERAL ERROR : Node: Output:Y [ShapeInferenceError] Mismatch between number of source and target dimensions. Source=0 Target=1

I've been stuck on this for hours. Can anyone help?


1 Answer


https://github.com/microsoft/onnxruntime/issues/380

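I changed a few things to get your code working: Y is now declared with shape [], and the model is polished with onnx.utils and saved with onnx.save instead of onnxmltools. Here is the new version: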

import numpy as np
import onnxruntime as rt
from onnx import *
import onnx.utils  # also binds the name `onnx`, which the code below uses
initializer = []
initializer.append(helper.make_tensor(name="W", data_type=TensorProto.FLOAT, dims=(1,), vals=np.ones(1).tolist()))

graph = helper.make_graph(
    [
        helper.make_node('MatMul', ["X", "W"], ["Y"]),
    ],
    "TEST",
    [
        helper.make_tensor_value_info('X' , TensorProto.FLOAT, [1]),
        helper.make_tensor_value_info('W', TensorProto.FLOAT, [1]),
    ],
    [
        # MatMul of two 1-D tensors produces a 0-D (scalar) output, so declare shape [].
        helper.make_tensor_value_info('Y', TensorProto.FLOAT, []),
    ],
    initializer=initializer,
    )

checker.check_graph(graph)
model = helper.make_model(graph, producer_name='TEST')
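# polish_model (check, strip doc strings, infer shapes, optimize) exists only in
# older onnx releases; it was removed along with onnx.optimizer.
final_model = onnx.utils.polish_model(model)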
onnx.save(final_model, 'model.onnx')
sess = rt.InferenceSession('model.onnx')

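To represent a scalar, use the shape `[]`, not `[1]`. ONNX's MatMul behaves like numpy.matmul, so multiplying two 1-D tensors of length 1 yields a 0-D (scalar) result. Declaring Y with shape [1] therefore conflicts with the inferred rank of 0, which is the mismatch the error reports (Source=0 Target=1).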
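Because ONNX MatMul is specified to follow numpy.matmul semantics, numpy is a quick way to work out which output shape to declare. The sketch below only uses numpy; the second case (W of shape [1, 1]) is an illustrative alternative, not something from the question, for when a rank-1 output of length 1 is really what you want.

```python
import numpy as np

# Shapes from the question: X is (1,) and W is (1,).
x = np.ones(1, dtype=np.float32)
w = np.ones(1, dtype=np.float32)
y = np.matmul(x, w)
print(y.shape)  # () : rank 0, so Y must be declared with shape []

# Hypothetical alternative: make W a 1x1 matrix. A 1-D @ 2-D product keeps
# one axis, so Y could then be declared with shape [1].
w_mat = np.ones((1, 1), dtype=np.float32)
y_vec = np.matmul(x, w_mat)
print(y_vec.shape)  # (1,)
```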

Answered 2019-06-09 04:15:25