# I designed this loss function for my idea, but it runs very slowly.
def batch_rank_loss(clss, embs):
    """Rank-agreement loss between two similarity spaces.

    For every anchor ``i`` and every pair ``(j, k)`` with ``i < j < k``,
    applies a hinge (ReLU) penalty when the cosine-similarity orderings of
    the two representations disagree, i.e. when
    ``cos_cls[i,j] - cos_cls[i,k]`` and ``cos_emb[i,j] - cos_emb[i,k]``
    have opposite signs.

    Vectorized replacement for the original triple Python loop: all
    ``i < j < k`` triples are evaluated with broadcasting in a few torch
    ops. Uses O(n^3) memory for the difference tensors (n = batch_size),
    which is fine for typical batch sizes.

    Args:
        clss: tensor of shape [batch_size, d1] (e.g. [batch_size, 768]).
        embs: tensor of shape [batch_size, d2] (e.g. [batch_size, 50]).

    Returns:
        Scalar tensor: summed hinge penalty over all i < j < k triples
        (a 0-valued tensor when batch_size < 3).
    """
    # Row-normalize so the Gram matrices below contain cosines.
    # NOTE(review): a zero row would divide by zero — same as the original;
    # assumes inputs are non-degenerate.
    clss_norm = clss / clss.norm(dim=1, keepdim=True)
    embs_norm = embs / embs.norm(dim=1, keepdim=True)
    clss_cos = clss_norm @ clss_norm.transpose(0, 1)  # [n, n] pairwise cosines
    embs_cos = embs_norm @ embs_norm.transpose(0, 1)  # [n, n] pairwise cosines

    n = clss.shape[0]
    # diff[i, j, k] = cos[i, j] - cos[i, k], for all triples at once.
    clss_diff = clss_cos[:, :, None] - clss_cos[:, None, :]
    embs_diff = embs_cos[:, :, None] - embs_cos[:, None, :]
    penalties = F.relu(-clss_diff * embs_diff)

    # Keep only triples with i < j < k — exactly the original loop ranges.
    idx = torch.arange(n, device=clss.device)
    mask = (idx[None, :, None] > idx[:, None, None]) & (
        idx[None, None, :] > idx[None, :, None]
    )
    return penalties[mask].sum()