3

我已将我的 PyTorch 模型导出到 ONNX。现在,我有没有办法从那个 ONNX 模型中获取输入层?

将 PyTorch 模型导出到 ONNX

import torch.onnx
checkpoint = torch.load("./saved_pytorch_model.pth")
model.load_state_dict(checkpoint['state_dict'])
input = torch.tensor(df_X.values).float()
torch.onnx.export(model, input, "onnx_model.onnx")

加载 ONNX 模型

onnx_model = onnx.load('onnx_model.onnx')

我希望能够以某种方式从 onnx_model 获取输入层。这可能吗?

4

2 回答 2

1

ONNX 模型是一个 protobuf 结构,定义见此处 ( https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto )。您可以使用为 python 生成的标准 protobuf 方法使用它(请参阅:https ://developers.google.com/protocol-buffers/docs/reference/python-generated )。我不明白你到底想提取什么。但是您可以遍历构成图的节点(model.graph.node)。图中的第一个节点可能对应于您可能认为的第一层,也可能不对应(这取决于翻译的完成方式)。您还可以获取模型的输入(model.graph.input)。

于 2019-08-06T21:08:19.003 回答
0

Onnx 库提供 API 来提取所有输入的名称和形状,如下所示:

model = onnx.load(onnx_model)
inputs = {}
for inp in model.graph.input:
    shape = str(inp.type.tensor_type.shape.dim)
    inputs[inp.name] = [int(s) for s in shape.split() if s.isdigit()]
于 2022-02-14T23:49:07.717 回答