
I'm trying to export a working streaming TorchScript model to ONNX, but I'm running into some problems.

I emphasize streaming because the model has some control flow, so AFAIK it needs to be exported via torch.jit.script.
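To illustrate why control flow forces the scripting path: torch.jit.trace records only the branch taken during the example run, while torch.jit.script preserves the `if` itself. A minimal sketch (the function and flag name are mine, not from the actual model):

```python
import torch

@torch.jit.script
def gate(x: torch.Tensor, is_final: bool) -> torch.Tensor:
    # This data-dependent branch survives scripting; tracing would
    # have baked in whichever branch the example input exercised.
    if is_final:
        return x * 2
    return x
```

Both branches remain callable from the scripted function, which is what a streaming model with an `is_final` flag relies on.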

...
    with torch.no_grad():
        stream_generator = StreamGenerator(training_config)
        stream_generator.load_state_dict(state_dict_sg)
        stream_generator.to(precision)
        # stream_generator.eval()
        # stream_generator.requires_grad_(False)
        script_model = torch.jit.script(stream_generator)

        # ONNX
        torch.onnx.export(
            model=script_model,
            # model=stream_generator,
            args=(data, {'is_final': True}),
            f="stream_vocoder.onnx",
            export_params=True,
            opset_version=13,
            do_constant_folding=True,
            input_names=['data', 'is_final'],
            output_names=['audio'],
            dynamic_axes={'data': {0: 'batch_size'},
                          'audio': {0: 'batch_size'}},
            verbose=True,
        )

        # Check
        onnx_model = onnx.load("stream_vocoder.onnx")
        onnx.checker.check_model(onnx_model)

The problem is that I get this RuntimeError:

Traceback (most recent call last):
  File "onnx_conversion.py", line 78, in <module>
    torch.onnx.export(
  File "/home/aalvarez/.virtualenvs/tts-inference-comp-klfduu67-py3.8/lib/python3.8/site-packages/torch/onnx/__init__.py", line 275, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/home/aalvarez/.virtualenvs/tts-inference-comp-klfduu67-py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 88, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/home/aalvarez/.virtualenvs/tts-inference-comp-klfduu67-py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 689, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/home/aalvarez/.virtualenvs/tts-inference-comp-klfduu67-py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 458, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args,
  File "/home/aalvarez/.virtualenvs/tts-inference-comp-klfduu67-py3.8/lib/python3.8/site-packages/torch/onnx/utils.py", line 401, in _create_jit_graph
    freezed_m = torch._C._freeze_module(model._c, preserveParameters=True)
RuntimeError: module contains attributes values that overlaps [ 0
[ torch.LongTensor{1} ],  0
[ torch.LongTensor{1} ],  0
[ torch.LongTensor{1} ]]

I looked up the error and found this: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/freeze_module.cpp#L348

I don't know where to start looking to solve this. Does anyone know how to fix or work around it?
