1

我正在尝试提取我的嵌入矩阵并将其归一化以运行余弦相似度。按照这个 github 回购:
https ://github.com/s4sarath/Deep-Learning-Projects/blob/master/variational_text_inference/model_evaluation.ipynb

embedding_matrix = find_norm(embedding_matrix)

我为此定义了一个函数:

def find_norm(syn0):
    syn0norm = (syn0 / np.sqrt((syn0 ** 2).sum(-1))[..., np.newaxis]).astype(np.float32)
    #syn0norm = (syn0 / sqrt((syn0 ** 2).sum(-1))[..., newaxis]).astype(REAL)
    return syn0norm

但是当我运行它时,我得到了上述错误。有人可以帮我吗?

4

1 回答 1

0

您正在向函数提供一个 TensorFlowTensor对象,而该find_norm函数需要一个 numpy 数组。

您可以运行张量流图,提取图并将其提供给您的find_norm函数,或者您可以重写该函数以使用张量对象(并输出张量)。

于 2020-04-24T09:52:03.583 回答