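I'm trying to export a working streaming TorchScript model to ONNX, but I'm running into some problems.
I emphasize "streaming" because it means the model has some control flow, so AFAIK it needs to be
exported via torch.jit.script.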
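The reason scripting is needed for control flow can be seen in a small sketch. The Gate module below is hypothetical and not part of the author's model: torch.jit.trace runs forward once on the example input and records only the branch that was taken, while torch.jit.script compiles the Python source and keeps the if in the graph.

```python
import torch


class Gate(torch.nn.Module):
    """Hypothetical toy module with a data-dependent branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Which branch runs depends on the *values* in x, not its shape.
        if x.sum() > 0:
            return x + 1
        return x - 1


pos = torch.ones(3)
neg = -torch.ones(3)

# Tracing runs forward once on `pos` and records only the ops it saw,
# so the traced graph is just "x + 1" for every input.
traced = torch.jit.trace(Gate(), pos)

# Scripting compiles the Python source, so the `if` is kept in the graph.
scripted = torch.jit.script(Gate())

print(traced(neg))    # positive branch baked in: every element is 0
print(scripted(neg))  # negative branch taken: every element is -2
```

Tracing this module typically also prints a TracerWarning about converting a tensor to a Python boolean, which is the hint that a branch was frozen into the trace.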
...
with torch.no_grad():
    stream_generator = StreamGenerator(training_config)
    stream_generator.load_state_dict(state_dict_sg)
    stream_generator.to(precision)
    # stream_generator.eval()
    # stream_generator.requires_grad_(False)
    script_model = torch.jit.script(stream_generator)
    # ONNX
    torch.onnx.export(
        model=script_model,
        # model=stream_generator,
        args=(data, {'is_final': True}),
        f="stream_vocoder.onnx",
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=['data', 'is_final'],
        output_names=['audio'],
        dynamic_axes={'data': {0: 'batch_size'},
                      'audio': {0: 'batch_size'}},
        verbose=True,
    )
# Check
onnx_model = onnx.load("stream_vocoder.onnx")
onnx.checker.check_model(onnx_model)
The problem is that I get this RuntimeError:
Traceback (most recent call last):
  File "onnx_conversion.py", line 78, in <module>
    torch.onnx.export(
  File "/home/aalvarez/.virtualenvs/tts-inference-comp-klfduu67-py3.8/lib/python3.8/site-packages/torch/onnx/__init__.py", line 275, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/home/aalvarez/.virtualenvs/tts-inference-comp-klfduu67-py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 88, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/home/aalvarez/.virtualenvs/tts-inference-comp-klfduu67-py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 689, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/home/aalvarez/.virtualenvs/tts-inference-comp-klfduu67-py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 458, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args,
  File "/home/aalvarez/.virtualenvs/tts-inference-comp-klfduu67-py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 401, in _create_jit_graph
    freezed_m = torch._C._freeze_module(model._c, preserveParameters=True)
RuntimeError: module contains attributes values that overlaps [ 0
[ torch.LongTensor{1} ], 0
[ torch.LongTensor{1} ], 0
[ torch.LongTensor{1} ]]
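I searched for the error and found where it is raised: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/freeze_module.cpp#L348
I don't know where to start looking to fix this. Does anyone know how to fix it or work around it?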
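Judging by its wording, the linked check in freeze_module.cpp rejects a module whose attribute values overlap, and the value printed in the message appears to be a list of three one-element LongTensors. A minimal diagnostic sketch, under the assumption that the overlap comes from the same tensor object being stored under more than one attribute (for example, a streaming state tensor also kept in a list) of the Python module before scripting. find_shared_tensors, _collect and Toy are hypothetical helpers, not PyTorch API; the sketch only compares tensor object identity, so two distinct views of one storage are not flagged, and parameters are deliberately ignored.

```python
import torch
from typing import Dict, List, Tuple


def _collect(value: object, path: str, out: List[Tuple[str, torch.Tensor]]) -> None:
    """Append (path, tensor) for every tensor inside a tensor/list/tuple/dict."""
    if isinstance(value, torch.Tensor):
        out.append((path, value))
    elif isinstance(value, (list, tuple)):
        for i, item in enumerate(value):
            _collect(item, f"{path}[{i}]", out)
    elif isinstance(value, dict):
        for key, item in value.items():
            _collect(item, f"{path}[{key!r}]", out)


def find_shared_tensors(module: torch.nn.Module) -> List[List[str]]:
    """Return groups of attribute paths that refer to the same tensor object."""
    by_id: Dict[int, List[str]] = {}
    for mod_name, mod in module.named_modules():
        prefix = f"{mod_name}." if mod_name else ""
        pairs: List[Tuple[str, torch.Tensor]] = []
        # Plain Python attributes; nn.Module keeps its bookkeeping in "_"-prefixed ones.
        for attr, value in vars(mod).items():
            if not attr.startswith("_"):
                _collect(value, prefix + attr, pairs)
        # Registered buffers live in mod._buffers, not directly in vars(mod).
        for name, buf in mod.named_buffers(recurse=False):
            pairs.append((prefix + name, buf))
        for path, tensor in pairs:
            by_id.setdefault(id(tensor), []).append(path)
    # Keep only tensors reachable from more than one attribute path.
    return [paths for paths in by_id.values() if len(paths) > 1]


class Toy(torch.nn.Module):
    """Hypothetical module that stores one tensor under two attributes."""

    def __init__(self) -> None:
        super().__init__()
        zero = torch.zeros(1, dtype=torch.long)
        self.cache = [zero, torch.zeros(1, dtype=torch.long)]
        self.offset = zero  # aliases self.cache[0]


print(find_shared_tensors(Toy()))  # [['cache[0]', 'offset']]
```

If the real StreamGenerator reports a group, giving each attribute its own tensor (for example with .clone()) before calling torch.jit.script is one thing to try; whether that alone gets past the freezing step for this model is untested.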