(警告:我对乘数的范围不是 [0, n)有点紧张,所以我调整了它。很容易弥补。)
我将使用经过测试的 Python 代码来绘制一个在O(log max {a, b})时间内运行的实现。首先,这里有一些实用函数和一个简单的实现。
from fractions import gcd
from random import randrange
def coprime(a, b):
return gcd(a, b) == 1
def floordiv(a, b):
return a // b
def ceildiv(a, b):
return floordiv(a + b - 1, b)
def count1(a, b, n, m):
assert 1 <= a < b
assert coprime(a, b)
assert 0 <= n < b + 1
assert 0 <= m < b + 1
return sum(k * a % b < m for k in range(n))
现在,我们怎样才能加快速度呢?第一个改进是将乘数划分为不相交的范围,使得在一个范围内,对应的 的倍数a
介于 的两个倍数之间b
。知道最低和最高值后,我们可以通过天花板除法计算小于 的倍数m
。
def count2(a, b, n, m):
assert 1 <= a < b
assert coprime(a, b)
assert 0 <= n < b + 1
assert 0 <= m < b + 1
count = 0
first = 0
while 0 < n:
count += min(ceildiv(m - first, a), n)
k = ceildiv(b - first, a)
n -= k
first = first + k * a - b
return count
这还不够快。第二个改进是用递归调用替换了大部分 while 循环。在下面的代码中,j
是“完整”的迭代次数,即存在环绕。term3
使用类似于 . 的逻辑来解释剩余的迭代count2
。
每个完整的迭代都在阈值下贡献floor(m / a)
或残差。我们是否得到取决于该迭代的内容。通过while循环在每次迭代中开始并以模数变化。只要它低于某个阈值,我们就会得到,并且这个计数可以通过递归调用来计算。floor(m / a) + 1
m
+ 1
first
first
0
a - (b % a)
a
+ 1
def count3(a, b, n, m):
assert 1 <= a < b
assert coprime(a, b)
assert 0 <= n < b + 1
assert 0 <= m < b + 1
if 1 == a:
return min(n, m)
j = floordiv(n * a, b)
term1 = j * floordiv(m, a)
term2 = count3(a - b % a, a, j, m % a)
last = n * a % b
first = last % a
term3 = min(ceildiv(m - first, a), (last - first) // a)
return term1 + term2 + term3
运行时间可以类似于欧几里得 GCD 算法进行分析。
这是一些测试代码来证明我的正确性声明。请记住在测试性能之前删除断言。
def test(p, f1, f2):
assert 3 <= p
for t in range(100):
while True:
b = randrange(2, p)
a = randrange(1, b)
if coprime(a, b):
break
for n in range(b + 1):
for m in range(b + 1):
args = (a, b, n, m)
print(args)
assert f1(*args) == f2(*args)
if __name__ == '__main__':
test(25, count1, count2)
test(25, count1, count3)