我正在尝试计算所有大矩阵对(3m x 2048)之间的余弦距离,并使用 pytorch 提取前 30 个相似向量。以下是我的代码,它工作正常,但每次迭代大约需要 30 秒,这对于 300 万个词向量来说太长了。有什么想法可以加快速度吗?
import torch.nn.functional as F
import torch
from tqdm import tqdm
import gc
sym_dict={}
tmp_list=[]
tot_dict=torch.load('xbx.pt')
all_tensors = torch.cat([v.unsqueeze(0) for k,v in tot_dict.items()], dim=0)
token_list= [i for i in tot_dict.keys()]
del tot_dict
gc.collect()
for counter ,value in tqdm(enumerate(token_list)):
uniq_vec=torch.unsqueeze(all_tensors[counter],dim=0)
dist = 1 - F.cosine_similarity(uniq_vec,all_tensors)
index_sorted = torch.argsort(dist)
roll_me=index_sorted[:30].cpu().numpy().tolist()
for ind in roll_me:
tmp_list.append(token_list[ind])
sym_dict.update({value:tmp_list})
tmp_list=[]
#save .pt file
torch.save(sym_dict,'sym_dict.pt')