
I am trying to convert the KoGPT2 model, which is pretrained on the GPT2 architecture, to ONNX format so that I can then convert the model to TensorFlow format.

I used convert_graph_to_onnx from transformers, but for some reason it does not work.

I don't know what this error means. Is it possible to export this model to ONNX format? Here is the code I ran, followed by the error.

Thank you.

import sys
!{sys.executable} -m pip install --upgrade git+https://github.com/huggingface/transformers
!{sys.executable} -m pip install --upgrade torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
!{sys.executable} -m pip install --upgrade onnxruntime==1.4.0
!{sys.executable} -m pip install -i https://test.pypi.org/simple/ ort-nightly
!{sys.executable} -m pip install --upgrade onnxruntime-tools
!rm -rf onnx/
from pathlib import Path
from transformers.convert_graph_to_onnx import convert

# Handles all the above steps for you
convert(framework="pt", model="skt/kogpt2-base-v2", output=Path('/content/drive/MyDrive/kogptonnx/kogpt.onnx'), opset=12)

# Tensorflow 
# convert(framework="tf", model="bert-base-cased", output="onnx/bert-base-cased.onnx", opset=11)
ONNX opset version set to: 11
Loading pipeline (model: skt/kogpt2-base-v2, tokenizer: skt/kogpt2-base-v2)
Some weights of the model checkpoint at skt/kogpt2-base-v2 were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Using framework PyTorch: 1.6.0+cpu
Found input input_ids with shape: {0: 'batch', 1: 'sequence'}
Found input attention_mask with shape: {0: 'batch', 1: 'sequence'}
Found output output_0 with shape: {0: 'batch', 1: 'sequence'}
Found output output_1 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_1 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_2 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_2 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_3 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_3 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_4 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_4 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_5 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_5 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_6 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_6 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_7 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_7 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_8 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_8 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_9 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_9 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_10 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_10 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_11 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_11 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_12 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Found output output_12 with shape: {0: 'batch', 1: 'sequence', 2: 'sequence'}
Ensuring inputs are in correct order
past_key_values is not present in the generated input list.
Generated inputs order: ['input_ids']
/usr/local/lib/python3.7/dist-packages/transformers/models/gpt2/modeling_gpt2.py:181: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
/usr/local/lib/python3.7/dist-packages/transformers/models/gpt2/modeling_gpt2.py:186: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-78cc7242cbdd> in <module>()
      4 
      5 # Handles all the above steps for you
----> 6 convert(framework="pt", model="skt/kogpt2-base-v2", output=Path('/content/drive/MyDrive/kogptonnx/kogpt.onnx'), opset=11)
      7 
      8 # Tensorflow

6 frames
/usr/local/lib/python3.7/dist-packages/torch/onnx/utils.py in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict)
    187 
    188         graph = torch._C._jit_pass_onnx(graph, operator_export_type)
--> 189         torch._C._jit_pass_lint(graph)
    190 
    191         torch._C._jit_pass_onnx_scalar_type_analysis(graph)

RuntimeError: Unable to cast from non-held to held instance (T& to Holder<T>) (compile in debug mode for type information)

1 Answer


Not sure whether it helps here, but I got the same Unable to cast from non-held to held instance error message with a different model (also transformers). In my case, adding the operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK option to torch.onnx.export(...) (as described here) fixed it for me:

torch.onnx.export(model, input, "output-name.onnx", export_params=True, opset_version=12, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
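Applied to the KoGPT2 case in the question, a minimal (untested) sketch of that approach could look like the code below. It assumes the checkpoint is loaded as a plain GPT2Model (as convert() does), that from_pretrained forwards use_cache=False and return_dict=False as config overrides so the model returns a single tensor instead of the 12 past key/value outputs seen in the log, and it reuses the output path from the question; the output name "last_hidden_state" and sequence length 16 are arbitrary choices:

import torch
from transformers import GPT2Model

# Load the checkpoint without the LM head; disable cached key/value outputs
# and dict-style outputs so tracing sees a single output tensor.
model = GPT2Model.from_pretrained("skt/kogpt2-base-v2", use_cache=False, return_dict=False)
model.eval()

# A dummy batch of token ids is enough for tracing; the actual values do not matter.
dummy_input = torch.randint(0, model.config.vocab_size, (1, 16), dtype=torch.long)

torch.onnx.export(
    model,
    (dummy_input,),
    "/content/drive/MyDrive/kogptonnx/kogpt.onnx",
    export_params=True,
    opset_version=12,
    input_names=["input_ids"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "last_hidden_state": {0: "batch", 1: "sequence"},
    },
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
)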
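If the export goes through, a quick sanity check with onnxruntime (already installed in the question's first cell) could look like this, reusing the file path and shapes from the sketch above:

import numpy as np
import onnxruntime as ort

# Run a dummy (batch=1, sequence=16) batch of token ids through the exported graph.
sess = ort.InferenceSession("/content/drive/MyDrive/kogptonnx/kogpt.onnx")
dummy = np.ones((1, 16), dtype=np.int64)
(last_hidden_state,) = sess.run(None, {"input_ids": dummy})
print(last_hidden_state.shape)  # expect (1, 16, hidden_size)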