我正在将一些复杂的 TF2 代码移植到 Pytorch。由于 TF2 不区分 Tensor 和 numpy 数组,所以说的很简单。然而,当我遇到几个错误说“你不能在 Pytorch 中混合 Tensor 和 numpy 数组!”时,我觉得我回到了 TF1 时代。这是原始的 TF2 代码:
def get_weighted_imgs(points, centers, imgs):
weights = np.array([[tf.norm(p - c) for c in centers] for p in points], dtype=np.float32)
weighted_imgs = np.array([[w * img for w, img in zip(weight, imgs)] for weight in weights])
weights = tf.expand_dims(1 / tf.reduce_sum(weights, axis=1), axis=-1)
weighted_imgs = tf.reshape(tf.reduce_sum(weighted_imgs, axis=1), [len(weights), 64*64*3])
return weights * weighted_imgs
还有我有问题的 Pytorch 代码:
def get_weighted_imgs(points, centers, imgs):
weights = torch.Tensor([[torch.norm(p - c) for c in centers] for p in points])
weighted_imgs = torch.Tensor([[w * img for w, img in zip(weight, imgs)] for weight in weights])
weights = torch.unsqueeze(1 / torch.sum(weights, dim=1), dim=-1)
weighted_imgs = torch.sum(weighted_imgs, dim=1).view([len(weights), 64*64*3])
return weights * weighted_imgs
def reproducible():
points = torch.Tensor(np.random.random((128, 5)))
centers = torch.Tensor(np.random.random((10, 5)))
imgs = torch.Tensor(np.random.random((10, 64, 64, 3)))
weighted_imgs = get_weighted_imgs(points, centers, imgs)
我可以保证张量/数组的维度顺序或形状没有问题。我得到的错误信息是
ValueError: only one element tensors can be converted to Python scalars
来自
weighted_imgs = torch.Tensor([[w * img for w, img in zip(weight, imgs)] for weight in weights])
有人可以帮我解决这个问题吗?那将不胜感激。