1

使用 pytorch 的 torchtext 构建文本分类模型。词汇对象在 data.field 中:

def create_tabularDataset_object(self,csv_path):
   self.TEXT = data.Field(tokenize=self.tokenizer,batch_first=True,include_lengths=True)
   self.LABEL = data.LabelField(dtype = torch.float,batch_first=True)
def get_vocab_with_glov(self,data):
   # initialize glove embeddings
   self.TEXT.build_vocab(data,min_freq=100,vectors = "glove.6B.100d")

训练后,在生产中为模型提供服务时,我如何持有 TEXT 对象?在预测时我需要它来索引单词标记

[TEXT.vocab.stoi[t] for t in tokenizedׁ_sentence]

我是否遗漏了某些东西并且没有必要持有该对象?除了模型重量,我还需要其他文件吗?

4

2 回答 2

5

实际上最好的方法(更稳定)是使用火炬内置函数 torch.save(*)

保存文件示例:

torch.save(vocab_obj, 'vocab_obj.pth')

再次加载文件:

vocab_obj = torch.load('vocab_obj.pth')
于 2021-05-01T11:55:30.753 回答
2

我发现我可以将它保存为 pkl:将 TEXT.vocab 保存为 pkl 工作:

def save_vocab(vocab, path):
    import pickle
    output = open(path, 'wb')
    pickle.dump(vocab, output)
    output.close()

在哪里

vocab = TEXT.vocab 

并像往常一样阅读它。

于 2020-04-16T12:40:59.083 回答