58

对于一些真正的价值bcin [-1, 1],我需要计算

sqrt( (b²*c²) / (1-c²) ) = (|b|*|c|) / sqrt((1-c)*(1+c))

c当接近 1 或 -1时,分母中出现灾难性抵消。平方根可能也无济于事。

我想知道是否可以在这里应用一个聪明的技巧来避免 c=1 和 c=-1 周围的困难区域?

4

2 回答 2

49

这种稳定性方面最有趣的部分是分母,sqrt(1 - c*c). 为此,您需要做的就是将其扩展为sqrt(1 - c) * sqrt(1 + c). 我不认为这真的有资格作为一个“聪明的把戏”,但这就是所需要的。

对于典型的二进制浮点格式(例如 IEEE 754 binary64,但其他常见格式应该表现得同样好,除了像双双格式这样不愉快的事情),如果c接近1then1 - c将被精确计算,由Sterbenz 引理,虽然1 + c没有任何稳定性问题。同样, if cis close to -1then1 + c将被精确计算,并且1 - c将被精确计算。平方根和乘法运算不会引入重大的新误差。

这是一个数值演示,在具有 IEEE 754 binary64 浮点数和正确舍入sqrt运算的机器上使用 Python。

让我们c接近(但小于)1

>>> c = float.fromhex('0x1.ffffffff24190p-1')
>>> c
0.9999999999

在这里我们必须小心一点:请注意,显示的十进制值0.999999999是的精确值的近似值c。确切的值显示在十六进制字符串的构造中,或以分数形式显示,562949953365017/562949953421312这正是我们关心获得良好结果的确切值。

表达式 的精确值,sqrt(1 - c*c)在点后四舍五入到小数点后 100 位,是:

0.0000141421362084401590649378320134409069878639187055610216016949959890888003204161068184484972504813

我使用 Python 的模块计算了这个,并使用Pari/GPdecimal仔细检查了结果。这是 Python 的计算:

>>> from decimal import Decimal, getcontext
>>> getcontext().prec = 1000
>>> good = (1 - Decimal(c) * Decimal(c)).sqrt().quantize(Decimal("1e-100"))
>>> print(good)
0.0000141421362084401590649378320134409069878639187055610216016949959890888003204161068184484972504813

如果我们天真地计算,我们会得到以下结果:

>>> from math import sqrt
>>> naive = sqrt(1 - c*c)
>>> naive
1.4142136208793713e-05

我们可以很容易地计算出 ulps 错误的近似数量(对正在进行的类型转换量表示歉意——float并且Decimal实例不能直接在算术运算中混合):

>>> from math import ulp
>>> float((Decimal(naive) - good) / Decimal(ulp(float(good))))
208701.28298527992

所以天真的结果是几十万 ulps - 粗略地说,我们已经失去了大约 5 个小数位的准确性。

现在让我们尝试扩展版本:

>>> better = sqrt(1 - c) * sqrt(1 + c)
>>> better
1.4142136208440158e-05
>>> float((Decimal(better) - good) / Decimal(ulp(float(good))))
-0.7170147200803595

所以在这里我们的准确度优于 1 ulp 错误。不完全正确地四舍五入,但下一个最好的事情。

通过更多的工作,应该可以说明并证明表达式中 ulps 误差数量的绝对上限sqrt(1 - c) * sqrt(1 + c),在域-1 < c < 1上,假设 IEEE 754 二进制浮点,舍入到偶数舍入模式,并在整个过程中进行正确的操作。我没有这样做,但如果这个上限超过 10 ulps,我会感到非常惊讶。

于 2021-01-13T18:07:32.100 回答
31

Mark Dickinson 为一般情况提供了一个很好的答案,我将用一种更专业的方法来补充它。

如今,许多计算环境都提供了一种称为融合乘加或简称 FMA 的操作,该操作是专门针对此类情况而设计的。在计算fma(a, b, c)完整乘积a * b(未截断和未舍入)时,会使用 进行加法运算c,然后在最后应用单个舍入。

目前出货的 GPU 和 CPU,包括基于 ARM64、x86-64 和 Power 架构的 GPU 和 CPU,通常包括 FMA 的快速硬件实现,它在 C 和 C++ 系列以及许多其他编程语言中作为标准公开数学函数fma()。一些(通常是较旧的)软件环境使用 FMA 的软件仿真,并且发现其中一些仿真存在缺陷。此外,这样的仿真往往很慢。

在 FMA 可用的情况下,表达式可以被评估为数值稳定且没有过早上溢和下溢的风险fabs (b * c) / sqrt (fma (c, -c, 1.0)),其中fabs()是浮点操作数的绝对值运算并sqrt()计算平方根。一些环境还提供倒数平方根运算,通常称为rsqrt(),在这种情况下,可能的替代方法是使用fabs (b * c) * rsqrt (fma (c, -c, 1.0))。使用rsqrt()避免了相对昂贵的除法,因此通常更快。但是, 的许多实现rsqrt()并没有像 那样正确舍入sqrt(),因此准确性可能会更差。

对下面代码的快速实验似乎表明基于 FMA 的表达式的最大误差约为 3 ulps,只要b正常的浮点数。我强调这并不能证明有任何错误限制。自动Herbie 工具,它试图找到给定浮点表达式的数值上有利的重写,建议使用fabs (b * c) * sqrt (1.0 / fma (c, -c, 1.0)). 然而,这似乎是一个虚假的结果,因为我既想不出任何特别的优势,也无法通过实验找到一个。

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <math.h>

#define USE_ORIGINAL  (0)
#define USE_HERBIE    (1)

/* function under test */
float func (float b, float c)
{
#if USE_HERBIE
     return fabsf (b * c) * sqrtf (1.0f / fmaf (c, -c, 1.0f));
#else USE_HERBIE
     return fabsf (b * c) / sqrtf (fmaf (c, -c, 1.0f));
#endif // USE_HERBIE
}

/* reference */
double funcd (double b, double c)
{
#if USE_ORIGINAL
    double b2 = b * b;
    double c2 = c * c;
    return sqrt ((b2 * c2) / (1.0 - c2));
#else
    return fabs (b * c) / sqrt (fma (c, -c, 1.0));
#endif
}

uint32_t float_as_uint32 (float a)
{
    uint32_t r;
    memcpy (&r, &a, sizeof r);
    return r;
}

float uint32_as_float (uint32_t a)
{
    float r;
    memcpy (&r, &a, sizeof r);
    return r;
}

uint64_t double_as_uint64 (double a)
{
    uint64_t r;
    memcpy (&r, &a, sizeof r);
    return r;
}

double floatUlpErr (float res, double ref)
{
    uint64_t i, j, err, refi;
    int expoRef;
    
    /* ulp error cannot be computed if either operand is NaN, infinity, zero */
    if (isnan (res) || isnan (ref) || isinf (res) || isinf (ref) ||
        (res == 0.0f) || (ref == 0.0f)) {
        return 0.0;
    }
    /* Convert the float result to an "extended float". This is like a float
       with 56 instead of 24 effective mantissa bits.
    */
    i = ((uint64_t)float_as_uint32(res)) << 32;
    /* Convert the double reference to an "extended float". If the reference is
       >= 2^129, we need to clamp to the maximum "extended float". If reference
       is < 2^-126, we need to denormalize because of the float types's limited
       exponent range.
    */
    refi = double_as_uint64(ref);
    expoRef = (int)(((refi >> 52) & 0x7ff) - 1023);
    if (expoRef >= 129) {
        j = 0x7fffffffffffffffULL;
    } else if (expoRef < -126) {
        j = ((refi << 11) | 0x8000000000000000ULL) >> 8;
        j = j >> (-(expoRef + 126));
    } else {
        j = ((refi << 11) & 0x7fffffffffffffffULL) >> 8;
        j = j | ((uint64_t)(expoRef + 127) << 55);
    }
    j = j | (refi & 0x8000000000000000ULL);
    err = (i < j) ? (j - i) : (i - j);
    return err / 4294967296.0;
}

// Fixes via: Greg Rose, KISS: A Bit Too Simple. http://eprint.iacr.org/2011/007
static unsigned int z=362436069,w=521288629,jsr=362436069,jcong=123456789;
#define znew (z=36969*(z&0xffff)+(z>>16))
#define wnew (w=18000*(w&0xffff)+(w>>16))
#define MWC  ((znew<<16)+wnew)
#define SHR3 (jsr^=(jsr<<13),jsr^=(jsr>>17),jsr^=(jsr<<5)) /* 2^32-1 */
#define CONG (jcong=69069*jcong+13579)                     /* 2^32 */
#define KISS ((MWC^CONG)+SHR3)

#define N  (20)

int main (void)
{
    float b, c, errloc_b, errloc_c, res;
    double ref, err, maxerr = 0;
    
    c = -1.0f;
    while (c <= 1.0f) {
        /* try N random values of `b` per every value of `c` */
        for (int i = 0; i < N; i++) {
            /* allow only normals */
            do {
                b = uint32_as_float (KISS);
            } while (!isnormal (b));
            res = func (b, c);
            ref = funcd ((double)b, (double)c);
            err = floatUlpErr (res, ref);
            if (err > maxerr) {
                maxerr = err;
                errloc_b = b;
                errloc_c = c;
            }
        }
        c = nextafterf (c, INFINITY);
    }
#if USE_HERBIE
    printf ("HERBIE max ulp err = %.5f @ (b=% 15.8e c=% 15.8e)\n", maxerr, errloc_b, errloc_c);
#else // USE_HERBIE
    printf ("SIMPLE max ulp err = %.5f @ (b=% 15.8e c=% 15.8e)\n", maxerr, errloc_b, errloc_c);
#endif // USE_HERBIE
    
    return EXIT_SUCCESS;
}
于 2021-01-13T20:13:00.057 回答