Let N(x, y) be the number of solutions of this problem. Obviously N(0, 0) is 1, since the only solution is (0, 0, 0, 0). And if either x or y is negative then there are no solutions, since we require a1, a2, a3, a4 to be all non-negative.
Otherwise, we can proceed by solving for the lowest bit, and generate a recurrence relation. Let's write n:0 and n:1 to mean 2n+0 and 2n+1 (so 0 and 1 are the lowest bits).
Then:
N(0, 0) = 1
N(-x, y) = N(x, -y) = 0
N(x:0, y:0) = N(x, y) + N(x-1, y) + N(x, y-1) + N(x-1, y-1)
N(x:0, y:1) = N(x:1, y:0) = 0
N(x:1, y:1) = 4 * N(x, y)
To see why these hold, consider the possible low bits of a1, a2, a3, a4.
Firstly N(x:0, y:0). We need the low bit of a1+a2 to be 0, which means that either both a1 and a2 are even, or they're both odd. If they're both odd, there's a carry, so the higher bits of a1 and a2 must sum to the higher bits of x minus 1. The same logic applies to a3, a4. There are 4 possibilities: the bottom bits of a1, a2, a3, a4 are all 0; only the bottom bits of a1, a2 are 1; only the bottom bits of a3, a4 are 1; or the bottom bits of a1, a2, a3, a4 are all 1. In each case the four bottom bits XOR to 0, and the cases give the terms N(x, y), N(x-1, y), N(x, y-1), N(x-1, y-1) respectively.
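Secondly N(x:0, y:1) and N(x:1, y:0). If one sum is even and the other odd, there are no solutions: the lowest bit of a1^a2^a3^a4 equals the lowest bit of x+y, which is 1 here, so the XOR cannot be 0. One can also check every combination for the lowest bits of a1, a2, a3, a4 to confirm this.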
Thirdly N(x:1, y:1). Exactly one of a1 and a2 must be odd, and similarly exactly one of a3 and a4 must be odd. There are 4 possibilities for this, and no carry in any of the cases.
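The three cases can also be checked mechanically. The sketch below is not part of the solution itself, and `low_bit_cases` is a name made up for this illustration. It enumerates all 16 combinations of low bits for a1, a2, a3, a4, keeps those whose XOR is 0, and groups them by the low bits of x and y, recording the carry that each pair passes to the higher bits.

```python
from collections import defaultdict
from itertools import product


def low_bit_cases():
    """Group valid low-bit assignments by the low bits of (x, y).

    Returns a dict mapping (x_low, y_low) to a sorted list of
    (carry_x, carry_y) pairs, one per assignment of low bits to
    a1, a2, a3, a4 whose XOR is 0.
    """
    cases = defaultdict(list)
    for b1, b2, b3, b4 in product((0, 1), repeat=4):
        if b1 ^ b2 ^ b3 ^ b4:
            continue  # XOR constraint fails at the lowest bit
        x_low, carry_x = (b1 + b2) % 2, (b1 + b2) // 2
        y_low, carry_y = (b3 + b4) % 2, (b3 + b4) // 2
        cases[(x_low, y_low)].append((carry_x, carry_y))
    return {key: sorted(val) for key, val in cases.items()}


for key, carries in sorted(low_bit_cases().items()):
    print(key, carries)
```

The (0, 0) entry holds one assignment per carry pair, matching the four terms of the even/even rule. The (1, 1) entry holds four carry-free assignments, matching 4 * N(x, y). The mixed-parity keys (0, 1) and (1, 0) never appear.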
Here's a complete solution:
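def N(x, y):
    if x == y == 0: return 1
    if x < 0 or y < 0: return 0
    if x % 2 == y % 2 == 0:
        # Both even: each pair has low bits (0, 0) or (1, 1); the latter carries.
        return N(x//2, y//2) + N(x//2-1, y//2) + N(x//2, y//2-1) + N(x//2-1, y//2-1)
    elif x % 2 == y % 2 == 1:
        # Both odd: 4 carry-free choices of low bits.
        return 4 * N(x//2, y//2)
    else:
        # Mixed parity: no solutions.
        return 0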
The algorithm makes several recursive calls, so in theory its running time is exponential in the number of bits of x and y. In practice many of the branches terminate quickly, so the code runs fast enough for values up to 2^30. You can also add a cache or use a dynamic programming table to guarantee a runtime of O(log(x)+log(y)).
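Here is a minimal sketch of the cached variant, assuming Python 3 and the standard `functools.lru_cache` decorator. `N_cached` is a name introduced here for illustration; the recurrence is unchanged, only the branches are reordered so the halved values are computed once.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def N_cached(x, y):
    """Same recurrence as N, with results memoized per (x, y) pair."""
    if x == 0 and y == 0:
        return 1
    if x < 0 or y < 0:
        return 0
    if x % 2 != y % 2:
        return 0  # mixed parity: no solutions
    hx, hy = x // 2, y // 2
    if x % 2 == 1:
        return 4 * N_cached(hx, hy)  # both odd: 4 carry-free choices
    # Both even: each pair has low bits (0, 0) with no carry, or (1, 1) with a carry.
    return (N_cached(hx, hy) + N_cached(hx - 1, hy)
            + N_cached(hx, hy - 1) + N_cached(hx - 1, hy - 1))


print(N_cached(2, 2))  # 5
```

At each recursion depth the first argument is either the shifted x or one less than it, and likewise for y, so the cache holds at most a handful of entries per bit.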
Finally, to increase confidence in its correctness, here are some tests against a naive O(xy) solution:
def N_slow(x, y):
    # Brute force: try every split of x and y and count those whose XOR is 0.
    s = 0
    for a1 in range(x + 1):
        for a3 in range(y + 1):
            a2 = x - a1
            a4 = y - a3
            if a1 ^ a2 ^ a3 ^ a4:
                continue
            s += 1
    return s

# Report any mismatch between the fast and the naive solution.
for x in range(50):
    for y in range(50):
        n = N(x, y)
        ns = N_slow(x, y)
        if n != ns:
            print('N(%d, %d) = %d, want %d' % (x, y, n, ns))