0

我尝试过的一个选项是腌制词汇并使用额外文件 arg 保存

import torch
import pickle

class Vocab(object):
    pass

vocab = Vocab()
pickle.dump(open('path/to/vocab.pkl','w'))

m = torch.jit.ScriptModule()

## I am not sure about the usage of this arg, the docs didn't help me
extra_files = torch._C.ExtraFilesMap()
extra_files['vocab.pkl'] = 'path/to/vocab.pkl'
# I also tried  pickle.dumps(vocab), and directly vocab

torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)

## Load with extra files.
files = {'vocab.pkl': ''}
torch.jit.load('scriptmodule.pt', _extra_files = files)

这给了

TypeError: import_ir_module(): incompatible function arguments. The following argument types are supported:
    1. (arg0: Callable[[List[str]], torch._C.ScriptModule], arg1: str, arg2: object, arg3: torch._C.ExtraFilesMap) -> None

其他选项显然是单独加载泡菜,但我正在寻找单个文件选项。

如果一个人可以将词汇添加到火炬脚本中,那就太好了......如果有一些我显然不知道的不这样做的原因,也很高兴知道。

4

3 回答 3

1

我认为文档torch.jit.load不正确。您需要创建一个 ExtraFilesmap() 对象来加载保存的文件。

以下是我如何开始工作的示例: 第 1 步:保存模型

extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
traced_script_module.save(serialized_model_path, _extra_files=extra_files)

第 2 步:加载模型

files = torch._C.ExtraFilesMap()
files['foo.txt'] = ''
loaded_model = torch.jit.load(serialized_model_path, _extra_files=files)
print(files)
于 2019-07-17T18:43:52.180 回答
0

问题位于 torch.jit.load 中。尝试检查您的 map_location

于 2019-06-03T01:09:53.803 回答
0

假设vocab是受支持的类型,您可以将其作为TorchScript 属性添加到模型中,以将其与模型一起存储在 1 个文件中(因此您不必处理_extra_files.

然后你的加载代码变成

torch.jit.load('scriptmodule.pt')
于 2019-07-01T21:07:52.767 回答