我正在尝试提取我的嵌入矩阵并将其归一化以运行余弦相似度。按照这个 github 回购:
https ://github.com/s4sarath/Deep-Learning-Projects/blob/master/variational_text_inference/model_evaluation.ipynb
embedding_matrix = find_norm(embedding_matrix)
我为此定义了一个函数:
def find_norm(syn0):
syn0norm = (syn0 / np.sqrt((syn0 ** 2).sum(-1))[..., np.newaxis]).astype(np.float32)
#syn0norm = (syn0 / sqrt((syn0 ** 2).sum(-1))[..., newaxis]).astype(REAL)
return syn0norm
但是当我运行它时,我得到了上述错误。有人可以帮我吗?