编辑3:
事实证明,考虑所有可能的变化比我乍一看更复杂。我的代码的第三次迭代对于所有可能的输入都应该是正确的。由于复杂性增加,我放弃了矢量化的 numpy 变体。生成器版本如下:
def overlapping_sectors3(sectors, interval):
"""
Yields overlapping radial intervals.
Returns the overlapping intervals between each of the sector-intervals
and the comparison-interval.
Args:
sectors: List of intervals.
Interval borders must be in [0, 2*pi).
interval: Single interval aginst which the overlap is calculated.
Interval borders must be in [0, 2*pi).
Yields:
A list of intervals marking the overlaping areas.
Interval borders are guaranteed to be in [0, 2*pi).
"""
i_lhs, i_rhs = interval
if i_lhs > i_rhs:
for s_lhs, s_rhs in sectors:
if s_lhs > s_rhs:
# CASE 1
o_lhs = max(s_lhs, i_lhs)
# o_rhs = min(s_rhs+2*np.pi, i_rhs+2*np.pi)
o_rhs = min(s_rhs, i_rhs)
# since o_rhs > 2pi > o_lhs
yield o_lhs, o_rhs
#o_lhs = max(s_lhs+2pi, i_lhs)
# o_rhs = min(s_rhs+4pi, i_rhs+2pi)
# since o_lhs and o_rhs > 2pi
o_lhs = s_lhs
o_rhs = i_rhs
if o_lhs < o_rhs:
yield o_lhs, o_rhs
else:
# CASE 2
o_lhs = max(s_lhs, i_lhs)
# o_rhs = min(s_rhs, i_rhs+2*np.pi)
o_rhs = s_rhs # since i_rhs + 2pi > 2pi > s_rhs
if o_lhs < o_rhs:
yield o_lhs, o_rhs
# o_lhs = max(s_lhs+2pi, i_lhs)
# o_rhs = min(s_rhs+2pi, i_rhs+2pi)
# since s_lhs+2pi > 2pi > i_lhs and both o_lhs and o_rhs > 2pi
o_lhs = s_lhs
o_rhs = min(s_rhs, i_rhs)
if o_lhs < o_rhs:
yield o_lhs, o_rhs
else:
for s_lhs, s_rhs in sectors:
if s_lhs > s_rhs:
# CASE 3
o_lhs = max(s_lhs, i_lhs)
o_rhs = i_rhs
if o_lhs < o_rhs:
yield o_lhs, o_rhs
o_lhs = i_lhs
o_rhs = min(s_rhs, i_rhs)
if o_lhs < o_rhs:
yield o_lhs, o_rhs
else:
# CASE 4
o_lhs = max(s_lhs, i_lhs)
o_rhs = min(s_rhs, i_rhs)
if o_lhs < o_rhs:
yield o_lhs, o_rhs
它可以通过以下方式进行测试:
import numpy as np
from collections import namedtuple
TestCase = namedtuple('TestCase', ['sectors', 'interval', 'expected', 'remark'])
testcases = []
def newcase(sectors, interval, expected, remark=None):
testcases.append( TestCase(sectors, interval, expected, remark) )
newcase(
[[280,70]],
[270,90],
[[280,70]],
"type 1"
)
newcase(
[[10,150]],
[270,90],
[[10,90]],
"type 2"
)
newcase(
[[10,150]],
[270,350],
[],
"type 4"
)
newcase(
[[50,350]],
[10,90],
[[50,90]],
"type 4"
)
newcase(
[[30,0]],
[300,60],
[[30,60],[300,0]],
"type 1"
)
newcase(
[[30,5]],
[300,60],
[[30,60],[300,5]],
"type 1"
)
newcase(
[[30,355]],
[300,60],
[[30,60],[300,355]],
"type 3"
)
def isequal(A,B):
if len(A) != len(B):
return False
A = np.array(A).round()
B = np.array(B).round()
a = set(map(tuple, A))
b = set(map(tuple, B))
return a == b
for caseindex, case in enumerate(testcases):
print("### testcase %2d ###" % caseindex)
print("sectors : %s" % case.sectors)
print("interval: %s" % case.interval)
if case.remark:
print(case.remark)
sectors = np.array(case.sectors)/180*np.pi
interval = np.array(case.interval)/180*np.pi
result = overlapping_sectors3(sectors, interval)
result = np.array(list(result))*180/np.pi
if isequal(case.expected, result):
print('PASS')
else:
print('FAIL')
print('\texp: %s' % case.expected)
print('\tgot: %s' % result)
要理解其背后的逻辑,请考虑以下几点:
- 每个区间都有一个左手边 (lhs) 和一个右手边 (rhs)
- 如果 lhs > rhs 则区间“环绕”,即它实际上是区间 [lhs, rhs+2pi]
- 在比较当前扇区和比较区间时,我们必须考虑四种情况
- 都环绕
- 只有比较区间环绕
- 只有扇区间隔环绕
- 没有人环绕
- 对于普通区间,重叠区间是
[o_lhs, o_rhs]
witho_lhs=max(lhs_1, lhs2)
和 o _rhs=min(rhs_1, rhs_2)
iffo_lhs < o_rhs
2pi
通过添加到 rhs iffrhs<lhs
产生间隔来“展开”所有间隔[0, 4*np.pi)
- 我们称
[0,2*pi)
第一次、[2*pi, 4*pi)
第二次和[4*pi, 6*pi)
第三次循环。
四种情况:
- 情况 4:两个区间都没有环绕,所以所有边界都在第一个循环内。我们可以像计算任何原始间隔一样计算重叠。
- 案例 2 和 3:恰好一个区间环绕。这意味着一个区间(我们称之为 a)完全在第一个回旋之内,而第二个(我们称之为 b)同时产生第一个和第二个回旋。这意味着,a 可以在第一次和第二次循环中与 b 相交。首先我们考虑第一个循环。它包含 a_lhs、a_rhs 和 b_lhs。b 的右侧我们认为是“展开的”,因此在
b_rhs+2pi
. 这产生o_lhs=max(a_lhs, b_lhs)
和o_rhs=a_rhs
。现在我们考虑第二个循环。它不仅包含 b at 的 rhs,b_rhs+2pi
还包含 a at 的周期性重复[a_lhs+2pi, a_rhs+2pi]
。这导致o_lhs=max(a_lhs+2pi, b_lhs)
和o_rhs=min(a_rhs+2pi, b_rhs+2pi)
。模数2pi
向下移动到o_lhs=a_lhs
和o_rhs=min(a_rhs, b_rhs)
。
- 案例 1:两个区间都产生回旋一和二。第一个交点在
[0, 4pi)
第二个交点之内,需要周期性重复其中一个区间,因此位于[2pi,6pi)
.
旧答案,已弃用:
这是我的版本,使用 numpy 向量操作。它可能可以通过使用更抽象的 numpy 函数(如 np.where 等)来改进。
另一个想法是忽略 numpy 并使用某种迭代器/生成器函数。也许我接下来会尝试类似的事情。
import numpy as np
sectors = np.array( [[5.23,0.50], [0.7,1.8], [1.9,3.71],[4.1,5.11]] )
interval = np.array([5.7,2.15])
def normalize_sectors(sectors):
# normalize might not be the best word here
idx = sectors[...,0] > sectors[...,1]
sectors[idx,1] += 2*np.pi
return sectors
def overlapping_sectors(sectors, interval):
# 'reverse' modulo 2*pi, so that rhs is always larger than lhs"
sectors = normalize_sectors(sectors)
interval = normalize_sectors(interval.reshape(1,2)).squeeze()
# when comparing two intervals A and B, the intersection is
# [max(A.left, B.left), min(A.right, B.right)
left = np.maximum(sectors[:,0], interval[0])
right = np.minimum(sectors[:,1], interval[1])
# construct overlapping intervals
res = np.hstack([left,right]).reshape((2,-1)).T
# neither empty (lhs=rhs) nor 'reversed' lhs>rhs intervals are allowed
res = res[res[:,0] < res[:,1]]
#reapply modulo
res = res % (2*np.pi)
return res
print(overlapping_sectors(sectors, interval))
编辑:
这里是基于迭代器的版本。它同样有效,但在数值上似乎有些逊色。
def overlapping_sectors2(sectors, interval):
i_lhs, i_rhs = interval
if i_lhs>i_rhs:
i_rhs += 2*np.pi
for s_lhs, s_rhs in sectors:
if s_lhs>s_rhs:
s_rhs += 2*np.pi
o_lhs = max(s_lhs, i_lhs)
o_rhs = min(s_rhs, i_rhs)
if o_lhs < o_rhs:
yield o_lhs % (2*np.pi), o_rhs % (2*np.pi)
print(list(overlapping_sectors2(sectors, interval)))
EDIT2:
现在支持在两个地方重叠的间隔。
sectors = np.array( [[30,330]] )/180*np.pi
interval = np.array( [300,60] )/180*np.pi
def normalize_sectors(sectors):
# normalize might not be the best word here
idx = sectors[...,0] > sectors[...,1]
sectors[idx,1] += 2*np.pi
return sectors
def overlapping_sectors(sectors, interval):
# 'reverse' modulo 2*pi, so that rhs is always larger than lhs"
sectors = normalize_sectors(sectors)
# if interval rhs is smaller than lhs, the interval crosses 360 degrees
# and we have to consider it as two intervals
if interval[0] > interval[1]:
interval_1 = np.array([interval[0], 2*np.pi])
interval_2 = np.array([0, interval[1]])
res_1 = _overlapping_sectors(sectors, interval_1)
res_2 = _overlapping_sectors(sectors, interval_2)
res = np.vstack((res_1, res_2))
else:
res = _overlapping_sectors(sectors, interval)
#reapply modulo
res = res % (2*np.pi)
return res
def _overlapping_sectors(sector, interval):
# when comparing two intervals A and B, the intersection is
# [max(A.left, B.left), min(A.right, B.right)
left = np.maximum(sectors[:,0], interval[0])
right = np.minimum(sectors[:,1], interval[1])
# construct overlapping intervals
res = np.hstack([left,right]).reshape((2,-1)).T
# neither empty (lhs=rhs) nor 'reversed' lhs>rhs intervals are allowed
res = res[res[:,0] < res[:,1]]
return res
print(overlapping_sectors(sectors, interval)*180/np.pi)
def overlapping_sectors2(sectors, interval):
i_lhs, i_rhs = interval
for s_lhs, s_rhs in sectors:
if s_lhs>s_rhs:
s_rhs += 2*np.pi
if i_lhs > i_rhs:
o_lhs = max(s_lhs, i_lhs)
o_rhs = min(s_rhs, 2*np.pi)
if o_lhs < o_rhs:
yield o_lhs % (2*np.pi), o_rhs % (2*np.pi)
o_lhs = max(s_lhs, 0)
o_rhs = min(s_rhs, i_rhs)
if o_lhs < o_rhs:
yield o_lhs % (2*np.pi), o_rhs % (2*np.pi)
else:
o_lhs = max(s_lhs, i_lhs)
o_rhs = min(s_rhs, i_rhs)
if o_lhs < o_rhs:
yield o_lhs % (2*np.pi), o_rhs % (2*np.pi)
print(np.array(list(overlapping_sectors2(sectors, interval)))*180/np.pi)