I have two nonnegative integers x and y, each with at most 30 bits (so their values can be up to about 10^9).

I'd like to count the sets of 4 numbers {a_1, a_2, a_3, a_4} such that a_1 + a_2 = x, a_3 + a_4 = y, and the xor of all 4 numbers is equal to 0.

What is the fastest algorithm to solve this problem?

The fastest I can think of is rearranging the xor equation to a_1 xor a_2 = a_3 xor a_4.

Then I can calculate all values of the left-hand side in O(x) and all values of the right-hand side in O(y), so the whole algorithm runs in O(x + y).
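A minimal sketch of that O(x + y) idea, assuming ordered quadruples (a_1, a_2, a_3, a_4) are what gets counted. The helper name `count_by_xor_tables` is hypothetical, not from the question:

```python
from collections import Counter

def count_by_xor_tables(x, y):
    """Count (a1, a2, a3, a4) with a1+a2 == x, a3+a4 == y, xor == 0."""
    # Tally every value that a1 ^ a2 can take when a1 + a2 == x.
    left = Counter(a ^ (x - a) for a in range(x + 1))
    # Tally every value that a3 ^ a4 can take when a3 + a4 == y.
    right = Counter(a ^ (y - a) for a in range(y + 1))
    # The total xor is 0 exactly when both sides give the same value.
    return sum(count * right[value] for value, count in left.items())
```

This uses O(x + y) time and memory, so it is only practical for small inputs.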

1 Answer

Let N(x, y) be the number of solutions of this problem, counted as ordered tuples (a1, a2, a3, a4). N(0, 0) is 1, since the only solution is (0, 0, 0, 0). And if either x or y is negative then there are no solutions, since we require a1, a2, a3, a4 to all be non-negative.

Otherwise, we can proceed by solving for the lowest bit, and generate a recurrence relation. Let's write n:0 and n:1 to mean 2n+0 and 2n+1 (so 0 and 1 are the lowest bits).

Then:

N(0, 0) = 1
N(x, y) = 0 if x < 0 or y < 0
N(x:0, y:0) = N(x, y) + N(x-1, y) + N(x, y-1) + N(x-1, y-1)
N(x:0, y:1) = N(x:1, y:0) = 0
N(x:1, y:1) = 4 * N(x, y)

To see why, consider the possible low bits of a1, a2, a3, a4.

First, N(x:0, y:0). The low bit of a1+a2 must be 0, so a1 and a2 are either both even or both odd. If both are odd, the low bits produce a carry, so the higher bits of a1 and a2 must sum to x-1 rather than x. The same logic applies to a3 and a4. That gives 4 cases for the low bits: all four are 0; only a1, a2 are 1; only a3, a4 are 1; all four are 1. In each case the low bits xor to 0, and the cases correspond to the four terms of the recurrence.

Second, N(x:0, y:1) and N(x:1, y:0). If one sum is even and the other odd, there are no solutions: the low bit of a1 xor a2 equals the low bit of a1+a2, and likewise for a3 and a4, so the low bit of the xor of all four numbers is 1, not 0.

Third, N(x:1, y:1). Exactly one of a1 and a2 must be odd, and similarly exactly one of a3 and a4 must be odd. There are 4 ways to choose which, the low bits xor to 0 in each, and none of them produces a carry.
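The three cases above can be checked mechanically by enumerating all 16 combinations of low bits. This is only an illustrative sketch; the function name `low_bit_cases` and the tuple layout are assumptions made for the example:

```python
from itertools import product
from collections import Counter

def low_bit_cases():
    """Group valid low-bit choices by (x bit, y bit, x carry, y carry)."""
    cases = Counter()
    for b1, b2, b3, b4 in product((0, 1), repeat=4):
        if b1 ^ b2 ^ b3 ^ b4:
            continue  # the low bit of the total xor must be 0
        x_bit, x_carry = (b1 + b2) % 2, (b1 + b2) // 2
        y_bit, y_carry = (b3 + b4) % 2, (b3 + b4) // 2
        cases[(x_bit, y_bit, x_carry, y_carry)] += 1
    return dict(cases)
```

Keys whose x bit and y bit differ never appear, the (1, 1) case appears 4 times with no carry, and the (0, 0) case appears once for each of the four carry combinations.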

Here's a complete solution:

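def N(x, y):
    # Counts ordered tuples (a1, a2, a3, a4) with a1+a2 == x, a3+a4 == y
    # and a1 ^ a2 ^ a3 ^ a4 == 0.
    if x == y == 0: return 1
    if x < 0 or y < 0: return 0
    if x % 2 == y % 2 == 0:
        # Both low bits are 0: each pair is both even (no carry) or both odd (carry).
        return N(x//2, y//2) + N(x//2-1, y//2) + N(x//2, y//2-1) + N(x//2-1, y//2-1)
    elif x % 2 == y % 2 == 1:
        # Both low bits are 1: 4 ways to pick the odd number in each pair, no carry.
        return 4 * N(x//2, y//2)
    else:
        # The parities of x and y differ, so the xor cannot be 0.
        return 0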

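Without a cache, the algorithm makes up to four recursive calls per level, so in theory it is exponential in the number of bits. In practice many of the branches terminate quickly, so the code runs plenty fast for values up to 2^30. You can add a cache or use a dynamic programming table to guarantee that the number of calls is O(log(x)+log(y)): at each bit level the arguments can only be the remaining high bits of x and y, or those values minus 1, so there are at most four distinct subproblems per level.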
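One way to add the cache, sketched with `functools.lru_cache` from the standard library. The name `N_cached` is hypothetical and the recurrence is the same as above:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def N_cached(x, y):
    """Memoized version of the recurrence for N(x, y)."""
    if x == y == 0:
        return 1
    if x < 0 or y < 0:
        return 0
    hx, hy = x // 2, y // 2  # the higher bits of x and y
    if x % 2 == 0 and y % 2 == 0:
        # Each pair either carries into the higher bits or does not.
        return (N_cached(hx, hy) + N_cached(hx - 1, hy)
                + N_cached(hx, hy - 1) + N_cached(hx - 1, hy - 1))
    if x % 2 == 1 and y % 2 == 1:
        # 4 ways to choose the odd member of each pair, no carry.
        return 4 * N_cached(hx, hy)
    return 0  # parities differ, so no solutions
```

For example, when x and y are both 2^30 - 1, every level takes the odd branch, so the result should be 4^30.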

Finally, to increase confidence in its correctness, here are some tests against a naive O(xy) solution:

def N_slow(x, y):
    # Brute force: try every split of x and y and count those with xor 0.
    s = 0
    for a1 in range(x + 1):
        for a3 in range(y + 1):
            a2 = x - a1
            a4 = y - a3
            if a1 ^ a2 ^ a3 ^ a4:
                continue
            s += 1
    return s

# Report any mismatch between the recurrence and the brute force.
for x in range(50):
    for y in range(50):
        n = N(x, y)
        ns = N_slow(x, y)
        if n != ns:
            print('N(%d, %d) = %d, want %d' % (x, y, n, ns))
Answered 2017-04-04T13:15:42.870