0

我为我的想法设计了这个损失函数,但它运行得很慢。

def batch_rank_loss(clss, embs):
    """Pairwise rank-correlation loss between two cosine-similarity spaces.

    For every ordered triple i < j < k of batch rows, penalizes the case
    where the cosine ordering of (j, k) relative to anchor i disagrees
    between the `clss` space and the `embs` space:

        loss += relu( -(cc[i,j] - cc[i,k]) * (ec[i,j] - ec[i,k]) )

    Vectorized replacement for the original O(n^3) Python triple loop:
    the differences for all (i, j, k) are built by broadcasting and the
    i < j < k constraint is applied with a boolean mask, so all work runs
    in a handful of C/CUDA kernels instead of per-element indexing.

    Args:
        clss: float tensor of shape [batch_size, d1] (e.g. [B, 768]).
        embs: float tensor of shape [batch_size, d2] (e.g. [B, 50]).

    Returns:
        Scalar tensor with the summed loss. For batch_size < 3 there are
        no valid triples and the result is a zero tensor (the original
        returned the Python int 0 in that case — still usable the same way).

    Note:
        Peak memory is O(batch_size^3) floats for the difference tensors;
        fine for typical batch sizes (<= a few hundred). Rows with zero
        norm produce NaNs, same as the original division.
    """
    # Row-normalize so the Gram matrices below are cosine similarities.
    clss_norm = clss / clss.norm(dim=1, keepdim=True)
    embs_norm = embs / embs.norm(dim=1, keepdim=True)
    clss_cos = clss_norm @ clss_norm.t()  # [B, B] cosines within clss
    embs_cos = embs_norm @ embs_norm.t()  # [B, B] cosines within embs

    n = clss.shape[0]
    # dc[i, j, k] = clss_cos[i, j] - clss_cos[i, k]  (and likewise de).
    dc = clss_cos.unsqueeze(2) - clss_cos.unsqueeze(1)
    de = embs_cos.unsqueeze(2) - embs_cos.unsqueeze(1)
    terms = torch.relu(-dc * de)  # [B, B, B], zero where orderings agree

    # Boolean mask selecting exactly the triples the original loops visited:
    # j > i and k > j, i.e. i < j < k.
    idx = torch.arange(n, device=clss.device)
    i_ax = idx.view(-1, 1, 1)
    j_ax = idx.view(1, -1, 1)
    k_ax = idx.view(1, 1, -1)
    mask = (j_ax > i_ax) & (k_ax > j_ax)

    return terms[mask].sum()
4

0 回答 0