我想将我的 pyTorch 模型转换为 ONNX。但是,我收到一条错误消息
RuntimeError: 提供的输入名称数 (9) 超过了输入数 (7) 但是,如果我从模型中取出两个 Dropout 层,我的代码将完美运行。
为什么是这样?
这是我的代码:
# Define the model
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.ReLU(),
torch.nn.Dropout(0.2), # problem with dropout layer
torch.nn.Linear(H, H),
torch.nn.LeakyReLU(),
torch.nn.Dropout(0.2), # problem with dropout layer
torch.nn.Linear(H, D_out),
torch.nn.Sigmoid()
)
checkpoint = torch.load("./saved_pytorch_model.pth") # load pyTorch model
model.load_state_dict(checkpoint['state_dict'])
features = torch.Tensor(df_X.values[0])
# Convert pyTorch model to ONNX
input_names = ['input_1']
output_names = ['output_1']
for key, module in model._modules.items():
input_names.append("l_{}_".format(key) + module._get_name())
torch_out = torch.onnx.export(model,
features,
"onnx_model.onnx",
export_params = True,
verbose = True,
input_names = input_names,
output_names = output_names,
)
我该怎么做才能将其导出到包含 Dropout 的 ONNX?