0

我试图找到一种最快的方法来计算 C++ 中任何浮点数的平方根。我在一个巨大的粒子运动计算中使用这种类型的函数,比如计算两个粒子之间的距离,我们需要一个平方根等。所以如果有任何建议会非常有帮助。我试过了,下面是我的代码

#include <math.h>
#include <iostream>
#include <chrono>

using namespace std;
using namespace std::chrono;

#define CHECK_RANGE 100

inline float msqrt(float a)
{
    int i;
    for (i = 0;i * i <= a;i++);
    
    float lb = i - 1; //lower bound
    if (lb * lb == a)
        return lb;
    float ub = lb + 1; // upper bound
    float pub = ub; // previous upper bound
    for (int j = 0;j <= 20;j++)
    {
        float ub2 = ub * ub;
        if (ub2 > a)
        {
            pub = ub;
            ub = (lb + ub) / 2; // mid value of lower and upper bound
        }
        else
        {
            lb = ub; 
            ub = pub;
        }
    }
    return ub;
}

void check_msqrt()
{
    for (size_t i = 0; i < CHECK_RANGE; i++)
    {
        msqrt(i);
    }
}

void check_sqrt()
{
    for (size_t i = 0; i < CHECK_RANGE; i++)
    {
        sqrt(i);
    }
}

int main()
{
    auto start1 = high_resolution_clock::now();
    check_msqrt();
    auto stop1 = high_resolution_clock::now();

    auto duration1 = duration_cast<microseconds>(stop1 - start1);
    cout << "Time for check_msqrt = " << duration1.count() << " micro secs\n";


    auto start2 = high_resolution_clock::now();
    check_sqrt();
    auto stop2 = high_resolution_clock::now();

    auto duration2 = duration_cast<microseconds>(stop2 - start2);
    cout << "Time for check_sqrt = " << duration2.count() << " micro secs";
    
    //cout << msqrt(3);

    return 0;
}

上述代码的输出显示实现的方法比 math.h 文件的 sqrt 慢 4 倍。我需要比 math.h 版本更快的版本。 在此处输入图像描述

4

3 回答 3

3

简而言之,我认为不可能实现比标准库版本更快的东西sqrt

在实现标准库函数时,性能是一个非常重要的参数,可以公平地假设这样一个常用的函数sqrt被尽可能地优化。

击败标准库函数需要特殊情况,例如:

  • 在标准库尚未专门针对的特定系统上提供合适的汇编指令 - 或其他专门的硬件支持。
  • 所需范围或精度的知识。标准库函数必须处理标准指定的所有情况。如果应用程序只需要其中的一个子集,或者可能只需要一个近似结果,那么也许可以进行优化。
  • 对计算进行数学简化或以智能方式组合一些计算步骤,以便可以为该组合进行有效实施。
于 2022-03-03T14:13:23.903 回答
1

这是二进制搜索的另一种替代方法。它可能没有那么快std::sqrt,还没有测试过。但它肯定会比你的二进制搜索更快。

auto
Sqrt(float x)
{
    using F = decltype(x);
    if (x == 0 || x == INFINITY || isnan(x))
        return x;
    if (x < 0)
        return F{NAN};
    int e;
    x = std::frexp(x, &e);
    if (e % 2 != 0)
    {
        ++e;
        x /= 2;
    }
    auto y = (F{-160}/567*x + F{2'848}/2'835)*x + F{155}/567;
    y = (y + x/y)/2;
    y = (y + x/y)/2;
    return std::ldexp(y, e/2);
}

在排除 +/-0、nan、inf 和负数之后,它通过将 分解为 [ 1 / 4 , 1) 乘以 2 efloat范围内的尾数来工作,其中是偶数。答案是 sqrt(mantissa)* 2 e / 2e

找到尾数的 sqrt 可以通过在 [ 1 / 4 , 1]范围内拟合的最小二乘二次曲线来猜测。然后通过牛顿-拉夫森的两次迭代来完善这个好的猜测。这将使您在正确舍入结果的1 ulp范围内。一个好的std::sqrt通常会得到最后一点正确。

于 2022-03-03T19:59:56.933 回答
0

我也尝试过https://en.wikipedia.org/wiki/Fast_inverse_square_root中提到的算法,但没有找到想要的结果,请检查

#include <math.h>
#include <iostream>
#include <chrono>

#include <bit>
#include <limits>
#include <cstdint>

using namespace std;
using namespace std::chrono;

#define CHECK_RANGE 10000

inline float msqrt(float a)
{
    int i;
    for (i = 0;i * i <= a;i++);
    
    float lb = i - 1; //lower bound
    if (lb * lb == a)
        return lb;
    float ub = lb + 1; // upper bound
    float pub = ub; // previous upper bound
    for (int j = 0;j <= 20;j++)
    {
        float ub2 = ub * ub;
        if (ub2 > a)
        {
            pub = ub;
            ub = (lb + ub) / 2; // mid value of lower and upper bound
        }
        else
        {
            lb = ub; 
            ub = pub;
        }
    }
    return ub;
}

/* mentioned here ->  https://en.wikipedia.org/wiki/Fast_inverse_square_root */
inline float Q_sqrt(float number)
{
    union Conv {
        float    f;
        uint32_t i;
    };
    Conv conv;
    conv.f= number;
    conv.i = 0x5f3759df - (conv.i >> 1);
    conv.f *= 1.5F - (number * 0.5F * conv.f * conv.f);
    return 1/conv.f;
}

void check_Qsqrt()
{
    for (size_t i = 0; i < CHECK_RANGE; i++)
    {
        Q_sqrt(i);
    }
}

void check_msqrt()
{
    for (size_t i = 0; i < CHECK_RANGE; i++)
    {
        msqrt(i);
    }
}

void check_sqrt()
{
    for (size_t i = 0; i < CHECK_RANGE; i++)
    {
        sqrt(i);
    }
}

int main()
{
    auto start1 = high_resolution_clock::now();
    check_msqrt();
    auto stop1 = high_resolution_clock::now();

    auto duration1 = duration_cast<microseconds>(stop1 - start1);
    cout << "Time for check_msqrt = " << duration1.count() << " micro secs\n";


    auto start2 = high_resolution_clock::now();
    check_sqrt();
    auto stop2 = high_resolution_clock::now();

    auto duration2 = duration_cast<microseconds>(stop2 - start2);
    cout << "Time for check_sqrt = " << duration2.count() << " micro secs\n";
    
    auto start3 = high_resolution_clock::now();
    check_Qsqrt();
    auto stop3 = high_resolution_clock::now();

    auto duration3 = duration_cast<microseconds>(stop3 - start3);
    cout << "Time for check_Qsqrt = " << duration3.count() << " micro secs\n";

    //cout << Q_sqrt(3);
    //cout << sqrt(3);
    //cout << msqrt(3);
    return 0;
}
于 2022-03-03T13:41:08.977 回答