列表推导是处理 numpy 数组的一种非常低效的方法。对于距离计算,它们是一个特别糟糕的选择。
要找到您的数据和点之间的差异,您只需执行data - point
. 然后,您可以使用 来计算距离np.hypot
,或者如果您愿意,可以将其平方、求和,然后取平方根。
不过,如果您将其设为 Nx2 数组以进行计算,则会更容易一些。
基本上,你想要这样的东西:
import numpy as np
data = np.array([[[1704, 1240],
[1745, 1244],
[1972, 1290],
[2129, 1395],
[1989, 1332]],
[[1712, 1246],
[1750, 1246],
[1964, 1286],
[2138, 1399],
[1989, 1333]],
[[1721, 1249],
[1756, 1249],
[1955, 1283],
[2145, 1399],
[1990, 1333]]])
point = [1989, 1332]
#-- Calculate distance ------------
# The reshape is to make it a single, Nx2 array to make calling `hypot` easier
dist = data.reshape((-1,2)) - point
dist = np.hypot(*dist.T)
# We can then reshape it back to AxBx1 array, similar to the original shape
dist = dist.reshape(data.shape[0], data.shape[1], 1)
print dist
这产生:
array([[[ 299.48121811],
[ 259.38388539],
[ 45.31004304],
[ 153.5219854 ],
[ 0. ]],
[[ 290.04310025],
[ 254.0019685 ],
[ 52.35456045],
[ 163.37074401],
[ 1. ]],
[[ 280.55837182],
[ 247.34186868],
[ 59.6405902 ],
[ 169.77926846],
[ 1.41421356]]])
现在,删除最近的元素比简单地获取最近的元素要困难一些。
使用 numpy,您可以使用布尔索引相当容易地做到这一点。
但是,您需要担心轴的对齐方式。
关键是要了解 numpy 沿最后一个轴的“广播”操作。在这种情况下,我们要沿中轴进行广播。
此外,-1
可以用作轴大小的占位符。-1
当作为轴的大小放入时,Numpy 将计算允许的大小。
我们需要做的看起来有点像这样:
#-- Remove closest point ---------------------
mask = np.squeeze(dist) != dist.min(axis=1)
filtered = data[mask]
# Once again, let's reshape things back to the original shape...
filtered = filtered.reshape(data.shape[0], -1, data.shape[2])
你可以把它写成一行,我只是为了便于阅读而把它分解。关键是dist != something
产生一个布尔数组,然后您可以使用它来索引原始数组。
所以,把它们放在一起:
import numpy as np
data = np.array([[[1704, 1240],
[1745, 1244],
[1972, 1290],
[2129, 1395],
[1989, 1332]],
[[1712, 1246],
[1750, 1246],
[1964, 1286],
[2138, 1399],
[1989, 1333]],
[[1721, 1249],
[1756, 1249],
[1955, 1283],
[2145, 1399],
[1990, 1333]]])
point = [1989, 1332]
#-- Calculate distance ------------
# The reshape is to make it a single, Nx2 array to make calling `hypot` easier
dist = data.reshape((-1,2)) - point
dist = np.hypot(*dist.T)
# We can then reshape it back to AxBx1 array, similar to the original shape
dist = dist.reshape(data.shape[0], data.shape[1], 1)
#-- Remove closest point ---------------------
mask = np.squeeze(dist) != dist.min(axis=1)
filtered = data[mask]
# Once again, let's reshape things back to the original shape...
filtered = filtered.reshape(data.shape[0], -1, data.shape[2])
print filtered
产量:
array([[[1704, 1240],
[1745, 1244],
[1972, 1290],
[2129, 1395]],
[[1712, 1246],
[1750, 1246],
[1964, 1286],
[2138, 1399]],
[[1721, 1249],
[1756, 1249],
[1955, 1283],
[2145, 1399]]])
附带说明,如果多个点同样接近,这将不起作用。Numpy 数组在每个维度上必须具有相同数量的元素,因此在这种情况下您需要重新进行分组。