1

我有以下问题,实际上来自我最近参加的编码测试:

问题:

存在一个函数f(n) = a*n + b*n*(floor(log(n)/log(2))) + c*n*n*n

在一个特定的值,让f(n) = k;

给定k, a, b, c,找到n

对于给定的 值k,如果不n存在值,则返回 0。

限制:

1 <= n < 2^63-1
0 < a, b < 100
0 <= c < 100
0 < k < 2^63-1

这里的逻辑是,由于f(n)对于给定的 a、b 和 c 纯粹是增加的,我可以n通过二进制搜索找到。

我写的代码如下:

#include<iostream>
#include<stdlib.h>
#include<math.h>
using namespace std;

unsigned long long logToBase2Floor(unsigned long long n){
    return (unsigned long long)(double(log(n))/double(log(2)));
}

#define f(n, a, b, c) (a*n + b*n*(logToBase2Floor(n)) + c*n*n*n)


unsigned long long findNByBinarySearch(unsigned long long k, unsigned long long a, unsigned long long b, unsigned long long c){
    unsigned long long low = 1;
    unsigned long long high = (unsigned long long)(pow(2, 63)) - 1;
    unsigned long long n;
    while(low<=high){
        n = (low+high)/2;
        cout<<"\n\n          k= "<<k;
        cout<<"\n f(n,a,b,c)= "<<f(n,a,b,c)<<"  low = "<<low<<"  mid="<<n<<"  high = "<<high;
        if(f(n,a,b,c) == k)
            return n;
        else if(f(n,a,b,c) < k)
            low = n+1;
        else high = n-1;
    }
    return 0;
}

然后我用一些测试用例进行了尝试:

int main(){
    unsigned long long n, a, b, c;
    n = (unsigned long long)pow(2,63)-1;
    a = 99;
    b = 99;
    c = 99;
    cout<<"\nn="<<n<<"  a="<<a<<"  b="<<b<<"  c="<<c<<"    k = "<<f(n, a, b, c);
    cout<<"\nANSWER: "<<findNByBinarySearch(f(n, a, b, c), a, b, c)<<endl;
    n = 1000;
    cout<<"\nn="<<n<<"  a="<<a<<"  b="<<b<<"  c="<<c<<"    k = "<<f(n, a, b, c);
    cout<<"\nANSWER: "<<findNByBinarySearch(f(n, a, b, c), a, b, c)<<endl;
    return 0;
}

然后奇怪的事情发生了。

该代码适用于测试用例n = (unsigned long long)pow(2,63)-1;,正确返回 n 的值。但它没有为n=1000. 我打印了输出并看到以下内容:

n=1000  a=99  b=99  c=99    k = 99000990000

          k= 99000990000
 f(n,a,b,c)= 4611686018427387904  low = 1  mid=4611686018427387904  high = 9223372036854775807
 ...
 ...
          k= 99000990000
 f(n,a,b,c)= 172738215936  low = 1  mid=67108864  high = 134217727

          k= 99000990000
 f(n,a,b,c)= 86369107968  low = 1  mid=33554432  high = 67108863

          k= 99000990000
 f(n,a,b,c)= 129553661952  low = 33554433  mid=50331648  high = 67108863**
 ...
 ...
          k= 99000990000
 f(n,a,b,c)= 423215328047139441  low = 37748737  mid=37748737  high = 37748737
ANSWER: 0

在数学上似乎有些不对劲。的值怎么f(1000)大于 的值f(33554432)

所以我在 Python 中尝试了相同的代码,得到了以下值:

>>> f(1000, 99, 99, 99)
99000990000L
>>> f(33554432, 99, 99, 99)
3740114254432845378355200L

所以,价值肯定更大。

问题:

  • 究竟发生了什么?
  • 我该如何解决?

4

1 回答 1

2

究竟发生了什么?

问题在这里:

unsigned long long low = 1;
// Side note: This is simply (2ULL << 62) - 1
unsigned long long high = (unsigned long long)(pow(2, 63)) - 1;
unsigned long long n;
while (/* irrelevant */) {
    n = (low + high) / 2;
    // Some stuff that do not modify n... 
    f(n, a, b, c) // <-- Here!
}

在第一次迭代中,你有low = 1high = 2^63 - 1,这意味着n = 2^63 / 2 = 2^62。现在,让我们看看f

#define f(n, a, b, c) (/* I do not care about this... */ + c*n*n*n)

你有n^3in f, so for n = 2^62, n^3 = 2^186,这对你来说可能太大了unsigned long long(可能是 64 位长)。

我该如何解决?

这里的主要问题是进行二分查找时的溢出,因此您应该单独处理溢出的情况。

前言:我使用ull_t是因为我很懒,你应该避免在 C++ 中使用宏,更喜欢使用函数并让编译器内联它。另外,我更喜欢循环而不是使用该log函数来计算 an 的 log2 (有关andunsigned long long的实现,请参见此答案的底部)。log2is_overflow

using ull_t = unsigned long long;

constexpr auto f (ull_t n, ull_t a, ull_t b, ull_t c) {
    if (n == 0ULL) { // Avoid log2(0)
        return 0ULL;
    }
    if (is_overflow(n, a, b, c)) {
        return 0ULL;
    }
    return a * n + b * n * log2(n) + c * n * n * n;
}

这是稍微修改的二进制搜索版本:

constexpr auto find_n (ull_t k, ull_t a, ull_t b, ull_t c) {
    constexpr ull_t max = std::numeric_limits<ull_t>::max();
    auto lb = 1ULL, ub = (1ULL << 63) - 1;
    while (lb <= ub) {
        if (ub > max - lb) {
            // This should never happens since ub < 2^63 and lb <= ub so lb + ub < 2^64
            return 0ULL;
        }
        // Compute middle point (no overflow guarantee).
        auto tn = (lb + ub) / 2;
        // If there is an overflow, then change the upper bound.
        if (is_overflow(tn, a, b, c)) {
            ub = tn - 1;
        }
        // Otherwize, do a standard binary search...
        else {
            auto val = f(tn, a, b, c);
            if (val < k) {
                lb = tn + 1;
            }
            else if (val > k) {
                ub = tn - 1;
            }
            else {
                return tn;
            }
        }
    }
    return 0ULL;
}

正如您所看到的,这里只有一个测试是相关的,即is_overflow(tn, a, b, c)(第一个测试在lb + ub这里是无关紧要的ub < 2^63lb <= ub < 2^63所以在我们的例子ub + lb < 2^64中是可以unsigned long long的)。

完整实现:

#include <limits>
#include <type_traits>

using ull_t = unsigned long long;

template <typename T, 
          typename = std::enable_if_t<std::is_integral<T>::value>>
constexpr auto log2 (T n) {
    T log = 0;
    while (n >>= 1) ++log;
    return log;
}

constexpr bool is_overflow (ull_t n, ull_t a, ull_t b, ull_t c) {
    ull_t max = std::numeric_limits<ull_t>::max();
    if (n > max / a) {
        return true;
    }
    if (n > max / b) {
        return true;
    }
    if (b * n > max / log2(n)) {
        return true;
    }
    if (c != 0) {
        if (n > max / c) return true;
        if (c * n > max / n) return true;
        if (c * n * n > max / n) return true;
    }
    if (a * n > max - c * n * n * n) {
        return true;
    }
    if (a * n + c * n * n * n > max - b * n * log2(n)) {
        return true;
    }
    return false;
}

constexpr auto f (ull_t n, ull_t a, ull_t b, ull_t c) {
    if (n == 0ULL) {
        return 0ULL;
    }
    if (is_overflow(n, a, b, c)) {
        return 0ULL;
    }
    return a * n + b * n * log2(n) + c * n * n * n;
}

constexpr auto find_n (ull_t k, ull_t a, ull_t b, ull_t c) {
    constexpr ull_t max = std::numeric_limits<ull_t>::max();
    auto lb = 1ULL, ub = (1ULL << 63) - 1;
    while (lb <= ub) {
        if (ub > max - lb) {
            return 0ULL; // Problem here
        }
        auto tn = (lb + ub) / 2;
        if (is_overflow(tn, a, b, c)) {
            ub = tn - 1;
        }
        else {
            auto val = f(tn, a, b, c);
            if (val < k) {
                lb = tn + 1;
            }
            else if (val > k) {
                ub = tn - 1;
            }
            else {
                return tn;
            }
        }
    }
    return 0ULL;
}

编译时间检查:

下面是一小段代码,您可以使用它在编译时检查上述代码是否(因为一切都是constexpr):

template <unsigned long long n, unsigned long long a, 
          unsigned long long b, unsigned long long c>
struct check: public std::true_type {
    enum {
        k = f(n, a, b, c)
    };
    static_assert(k != 0, "Value out of bound for (n, a, b, c).");
    static_assert(n == find_n(k, a, b, c), "");
};

template <unsigned long long a, 
          unsigned long long b, 
          unsigned long long c>
struct check<0, a, b, c>: public std::true_type {
    static_assert(a != a, "Ambiguous values for n when k = 0.");
};

template <unsigned long long n>
struct check<n, 0, 0, 0>: public std::true_type {
    static_assert(n != n, "Ambiguous values for n when a = b = c = 0.");
};

#define test(n, a, b, c) static_assert(check<n, a, b, c>::value, "");

test(1000, 99, 99, 0);
test(1000, 99, 99, 99);
test(453333, 99, 99, 99);
test(495862, 99, 99, 9);
test(10000000, 1, 1, 0);

注意:的最大值k约为2^63,因此对于给定的三元组(a, b, c), 的最大值nf(n, a, b, c) < 2 ^ 63f(n + 1, a, b, c) >= 2 ^ 63。对于a = b = c = 99,这个最大值是n = 453333(根据经验找到的),这就是我在上面测试它的原因。

于 2016-08-10T06:48:22.537 回答