1

我正在使用 Faiss 来索引我的巨大数据集嵌入,嵌入从 bert 模型生成。我想逐步添加嵌入,如果我只用 faiss.IndexFlatL2 添加它就可以了,但问题是在保存它时它的大小太大了。所以我尝试使用 faiss.IndexIVFPQ,但它需要在添加数据之前训练嵌入,所以我不能增量添加它,我必须先计算所有嵌入然后训练并添加它,它有问题,因为所有数据应该保存在 RAM 中,直到我写下来。有没有办法逐步做到这一点。这是我的代码:

    # It is working fine when using with IndexFlatL2
    def __init__(self, sentences, model):
        self.sentences = sentences
        self.model = model
        self.index = faiss.IndexFlatL2(768)

    def process_sentences(self):
        result = self.model(self.sentences)
        self.sentence_ids = []
        self.token_ids = []
        self.all_tokens = []
        for i, (toks, embs) in enumerate(tqdm(result)):
            # initialize all_embeddings for every new sentence (INCREMENTALLY)
            all_embeddings = []
            for j, (tok, emb) in enumerate(zip(toks, embs)):
                self.sentence_ids.append(i)
                self.token_ids.append(j)
                self.all_tokens.append(tok)
                all_embeddings.append(emb)

            all_embeddings = np.stack(all_embeddings) # Add embeddings after every sentence
            self.index.add(all_embeddings)

        faiss.write_index(self.index, "faiss_Model")

当与 IndexIVFPQ 一起使用时:

   def __init__(self, sentences, model):
       self.sentences = sentences
       self.model = model
       self.quantizer = faiss.IndexFlatL2(768)
       self.index = faiss.IndexIVFPQ(self.quantizer, 768, 1000, 16, 8)

   def process_sentences(self):
       result = self.model(self.sentences)
       self.sentence_ids = []
       self.token_ids = []
       self.all_tokens = []
       all_embeddings = []
       for i, (toks, embs) in enumerate(tqdm(result)):
           for j, (tok, emb) in enumerate(zip(toks, embs)):
               self.sentence_ids.append(i)
               self.token_ids.append(j)
               self.all_tokens.append(tok)
               all_embeddings.append(emb)

       all_embeddings = np.stack(all_embeddings)
       self.index.train(all_embeddings) # Train
       self.index.add(all_embeddings) # Add to index
       faiss.write_index(self.index, "faiss_Model_mini")
4

0 回答 0