我想用 屏蔽分数矩阵中的所有零-np.inf,但我只能屏蔽部分零,看起来像
你在右上角看到仍然有没有被掩盖的零-np.inf
这是我的代码:
q = torch.Tensor([np.random.random(10),np.random.random(10),np.random.random(10), np.random.random(10), np.zeros((10,1)), np.zeros((10,1))])
k = torch.Tensor([np.random.random(10),np.random.random(10),np.random.random(10), np.random.random(10), np.zeros((10,1)), np.zeros((10,1))])
scores = torch.matmul(q, k.transpose(0,1)) / math.sqrt(10)
mask = torch.Tensor([1,1,1,1,0,0])
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask==0, -np.inf)
也许面具是错的?
