已有一个问题“3 个长整数的平均值”，专门讨论如何高效计算三个有符号整数的平均值。
然而，使用无符号整数可以进行一些额外的优化，而这些优化并不适用于上一个问题所涵盖的场景。本问题讨论的是如何高效计算三个无符号整数的平均值，其中平均值向零舍入，即用数学语言表述，我想计算 ⌊(a + b + c) / 3⌋。
计算该平均值的一种直接方法是：
avg = a / 3 + b / 3 + c / 3 + (a % 3 + b % 3 + c % 3) / 3;
现代优化编译器会把其中的除法转换为“乘以倒数再移位”，并把取模运算转换为回乘（back-multiply）加一次减法。其中回乘可以利用许多架构上都有的 scale_add 惯用指令，例如 x86_64 上的 `lea`、ARM 上带 `lsl #n` 的 `add`，以及 NVIDIA GPU 上的 `iscadd`。
在尝试以适用于许多常见平台的通用方式优化上述代码时，我观察到整数运算的成本通常满足如下关系：逻辑运算 ≤ (add | sub) ≤ shift ≤ scale_add ≤ mul。这里的“成本”综合考虑了延迟、吞吐量限制和功耗。当处理的整数类型比本机寄存器更宽时（例如在 32 位处理器上处理 `uint64_t` 数据），这些差异会变得更加明显。
因此，我的优化策略是尽量减少指令数，并在可能的情况下用“廉价”操作替换“昂贵”操作。同时，不增加寄存器压力，并为各种乱序执行处理器保留可利用的并行性。
第一个观察是：可以先用 CSA（进位保存加法器）把三个操作数之和归约为两个操作数之和。CSA 产生一个和值（sum）和一个进位值（carry），其中进位值的权重是和值的两倍，即 a + b + c = sum + 2 · carry。在大多数处理器上，软件实现的 CSA 需要 5 条逻辑运算指令。有些处理器（例如 NVIDIA GPU）提供 `LOP3` 指令，一条指令即可计算三个操作数的任意逻辑表达式，此时 CSA 可以压缩为两条 `LOP3`。（注意：我还没能让 CUDA 编译器生成这两条 `LOP3`；它目前生成了四条 `LOP3`！）
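第二个观察是：由于我们要计算的是模 3 的余数，因此不需要用回乘来求余数。我们可以改用 `dividend % 3 = ((dividend / 3) + dividend) & 3`。既然已经有了除法结果，求模就只需要一次加法和一次逻辑运算。这是一个通用规律的特例（dividend 表示被除数）：$\text{dividend} \bmod (2^n - 1) = \left( \left\lfloor \text{dividend} / (2^n - 1) \right\rfloor + \text{dividend} \right) \mathbin{\&} (2^n - 1)$。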
最后，对于校正项 `(a % 3 + b % 3 + c % 3) / 3` 中的除以 3，我们不需要通用的除以 3 代码。由于被除数非常小（在 [0, 6] 范围内），可以把 `x / 3` 简化为 `(3 * x) / 8`，只需要一次 scale_add 加一次 shift。
下面的代码展示了我目前的进展。用 Compiler Explorer 检查为各种平台生成的代码（以 `-O3` 编译）可以看到，生成的正是我期望的紧凑代码。
然而，在我的 Ivy Bridge x86_64 机器上用 Intel 13.x 编译器对代码计时时，一个缺陷显现出来。与简单版本相比，我的代码改善了延迟：对于 `uint64_t` 数据，从 18 个周期降到 15 个周期。但吞吐量变差了：对于 `uint64_t` 数据，从每 6.8 个周期一个结果变为每 8.5 个周期一个结果。仔细查看汇编代码后，原因就很清楚了：我基本上把代码的并行度从大约三路降到了大约两路。
是否有一种普遍适用的优化技术，对常见处理器（尤其是各类 x86、ARM 以及 GPU）都有益，并且能保留更多的并行性？或者，是否有一种优化技术可以进一步减少总操作数，以弥补并行度的降低？校正项的计算（下面代码中的 `tail`）似乎是一个很好的优化目标。把它简化为 `(carry_mod_3 + sum_mod_3) / 2` 看起来很诱人，但在九种可能的组合中有一种会给出错误结果。例如 carry_mod_3 = 0、sum_mod_3 = 2 时，正确值为 (2·0 + 2) / 3 = 0，而简化式给出 1。
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
/* Build configuration.
 *
 * BENCHMARK          nonzero: time average_of_3() on uint64_t operands.
 *                    zero:    exhaustively check average_of_3() on uint8_t.
 * SIMPLE_COMPUTATION nonzero: average_of_3() uses the straightforward formula.
 *                    zero:    average_of_3() uses the carry-save-adder variant.
 *
 * Defaults are BENCHMARK=1 and SIMPLE_COMPUTATION=0. Either switch can be
 * overridden on the compiler command line (e.g. -DBENCHMARK=0) without
 * editing this file.
 */
#ifndef BENCHMARK
#define BENCHMARK (1)
#endif
#ifndef SIMPLE_COMPUTATION
#define SIMPLE_COMPUTATION (0)
#endif
/* Operand type under test: full 64-bit width when benchmarking, 8-bit when
   testing so that all 2^24 operand triples can be enumerated. */
#if BENCHMARK
#define T uint64_t
#else // !BENCHMARK
#define T uint8_t
#endif // BENCHMARK
/* Return floor((a + b + c) / 3) for unsigned operands of type T, without ever
 * forming the (possibly overflowing) full sum a + b + c.
 *
 * Outline of the default (!SIMPLE_COMPUTATION) path:
 *   1. A carry-save adder reduces three operands to two:
 *      a + b + c == sum + 2 * carry, where neither sum nor carry overflows.
 *   2. Each term is split as x == 3 * (x / 3) + x % 3. The remainder is
 *      recovered from the already computed quotient without a back-multiply:
 *      x + x / 3 == 4 * (x / 3) + x % 3, so its low two bits are x % 3.
 *   3. The quotients form "head". The remainders contribute
 *      (2 * (carry % 3) + sum % 3) / 3, whose argument lies in [0, 6];
 *      on that range x / 3 == (3 * x) / 8, so "tail" needs no real division.
 *
 * The trailing // comments list the instructions each statement is expected
 * to compile to (the author's cost model, not a guarantee).
 *
 * NOTE(review): with T = uint8_t (the exhaustive-test configuration),
 * sizeof (size_t) can never equal sizeof (T), because size_t is at least
 * 16 bits wide. The self-test therefore covers only the else branch below,
 * not the native-width branch.
 */
T average_of_3 (T a, T b, T c)
{
T avg;
#if SIMPLE_COMPUTATION
/* reference formula: divide each operand first so nothing can overflow, then
   add the quotient of the three discarded remainders (their sum is in [0, 6]) */
avg = a / 3 + b / 3 + c / 3 + (a % 3 + b % 3 + c % 3) / 3;
#else // !SIMPLE_COMPUTATION
/* carry save adder */
T a_xor_b = a ^ b;
T sum = a_xor_b ^ c;
T carry = (a_xor_b & c) | (a & b);
/* here 2 * carry + sum = a + b + c */
T sum_div_3 = (sum / 3); // {MUL|MULHI}, SHR
/* sum + sum/3 == 4 * (sum/3) + sum%3, so masking with 3 yields sum % 3 */
T sum_mod_3 = (sum + sum_div_3) & 3; // ADD, AND
if (sizeof (size_t) == sizeof (T)) { // "native precision" (well, not always)
/* T as wide as size_t: work with 2 * carry directly so the doubling can be
   folded into a scale_add */
T two_carry_div_3 = (carry / 3) * 2; // MULHI, ANDN
/* 2*carry + 2*(carry/3) == 8 * (carry/3) + 2 * (carry%3); masking with 6
   leaves 2 * (carry % 3) */
T two_carry_mod_3 = (2 * carry + two_carry_div_3) & 6; // SCALE_ADD, AND
T head = two_carry_div_3 + sum_div_3; // ADD
/* argument is in [0, 6], where (3 * x) / 8 == x / 3 */
T tail = (3 * (two_carry_mod_3 + sum_mod_3)) / 8; // ADD, SCALE_ADD, SHR
avg = head + tail; // ADD
} else {
/* T narrower or wider than size_t: same decomposition, doubling applied
   afterwards */
T carry_div_3 = (carry / 3); // MUL, SHR
/* same remainder trick as for sum: low two bits of carry + carry/3 */
T carry_mod_3 = (carry + carry_div_3) & 3; // ADD, AND
T head = (2 * carry_div_3 + sum_div_3); // SCALE_ADD
/* argument is in [0, 6], where (3 * x) / 8 == x / 3 */
T tail = (3 * (2 * carry_mod_3 + sum_mod_3)) / 8; // SCALE_ADD, SCALE_ADD, SHR
avg = head + tail; // ADD
}
#endif // SIMPLE_COMPUTATION
return avg;
}
#if !BENCHMARK
/* Exhaustive self-check (built when BENCHMARK is zero).
 *
 * Every triple (a, b, c) of T values is fed to average_of_3() and compared
 * with a reference computed in 64-bit arithmetic, where the plain sum cannot
 * overflow. The first mismatch is printed and EXIT_FAILURE is returned;
 * EXIT_SUCCESS is returned if all triples agree.
 */
int main (void)
{
    T x = 0;
    do {
        T y = 0;
        do {
            T z = 0;
            do {
                T got = average_of_3 (x, y, z);
                T want = ((uint64_t)x + (uint64_t)y + (uint64_t)z) / 3;
                if (got != want) {
                    printf ("a=%08x b=%08x c=%08x res=%08x ref=%08x\n",
                            x, y, z, got, want);
                    return EXIT_FAILURE;
                }
            } while (++z != 0);   /* stops once z wraps back to zero */
        } while (++y != 0);
    } while (++x != 0);
    return EXIT_SUCCESS;
}
#else // BENCHMARK
#include <math.h>
/* second(): current time in seconds as a double, meant to be differenced
 * between two calls to measure an elapsed interval.
 *   Windows:     QueryPerformanceCounter() scaled by the counter frequency;
 *                if no performance counter is reported, GetTickCount()
 *                (milliseconds) is used instead.
 *   Linux/macOS: gettimeofday() (tv_usec is in microseconds).
 * Any other platform is rejected at compile time.
 */
#if defined(_WIN32)
#if !defined(WIN32_LEAN_AND_MEAN)
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>
double second (void)
{
    static int timer_probed;
    static BOOL perf_counter_ok;
    static double seconds_per_tick;
    LARGE_INTEGER ticks;

    /* probe the performance counter once and cache its period */
    if (!timer_probed) {
        perf_counter_ok = QueryPerformanceFrequency (&ticks);
        seconds_per_tick = 1.0 / (double)ticks.QuadPart;
        timer_probed = 1;
    }
    if (!perf_counter_ok) {
        return (double)GetTickCount() * 1.0e-3;
    }
    QueryPerformanceCounter (&ticks);
    return (double)ticks.QuadPart * seconds_per_tick;
}
#elif defined(__linux__) || defined(__APPLE__)
#include <stddef.h>
#include <sys/time.h>
double second (void)
{
    struct timeval now;
    gettimeofday (&now, NULL);
    return (double)now.tv_sec + 1.0e-6 * (double)now.tv_usec;
}
#else
#error unsupported platform
#endif
/* Iterations of each timed inner loop; every iteration makes 16 calls. */
#define N (3000000)
/* Benchmark driver (built when BENCHMARK is nonzero).
 *
 * Latency: each iteration makes 16 calls that each consume the previous
 * result avg0, forming one serial dependency chain.
 * Throughput: each iteration makes 16 calls on sixteen separate accumulators
 * avg0..avg15, which do not depend on each other within the iteration.
 *
 * In both loops, a and b are updated from avg0 at the end of every
 * iteration, so successive iterations see different operands. Each
 * measurement is repeated 5 times; the fastest run divided by 16 * N calls
 * is reported. The final values are printed, then overwritten on the
 * terminal by the leading '\r' of the next printf (presumably to keep the
 * results live so the computation is not optimised away).
 */
int main (void)
{
    double start, stop, elapsed;
    int i, k;
    T a, b;
    /* sixteen distinct starting values, one per independent chain */
    T avg0 = 0xffffffff, avg1 = 0xfffffffe;
    T avg2 = 0xfffffffd, avg3 = 0xfffffffc;
    T avg4 = 0xfffffffb, avg5 = 0xfffffffa;
    T avg6 = 0xfffffff9, avg7 = 0xfffffff8;
    T avg8 = 0xfffffff7, avg9 = 0xfffffff6;
    T avg10 = 0xfffffff5, avg11 = 0xfffffff4;
    T avg12 = 0xfffffff3, avg13 = 0xfffffff2;
    T avg14 = 0xfffffff1, avg15 = 0xfffffff0;

    /* ---- latency: one chain of 16 dependent calls per iteration ---- */
    a = 0x31415926;
    b = 0x27182818;
    avg0 = average_of_3 (a, b, avg0);
    elapsed = INFINITY;
    for (k = 0; k < 5; k++) {
        start = second();
        for (i = 0; i < N; i++) {
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            avg0 = average_of_3 (a, b, avg0);
            b = (b + avg0) ^ a;
            a = (a ^ b) + avg0;
        }
        stop = second();
        elapsed = fmin (stop - start, elapsed);
    }
    /* %llx requires unsigned long long; uint64_t may be plain unsigned long */
    printf ("a=%016llx b=%016llx avg=%016llx",
            (unsigned long long)a, (unsigned long long)b,
            (unsigned long long)avg0);
    printf ("\rlatency: each average_of_3() took %.6e seconds\n",
            elapsed / 16 / N);

    /* ---- throughput: 16 independent chains per iteration ---- */
    a = 0x31415926;
    b = 0x27182818;
    avg0 = average_of_3 (a, b, avg0);
    /* restart the minimum; otherwise a throughput time above the latency
       minimum would be masked and reported as the latency time */
    elapsed = INFINITY;
    for (k = 0; k < 5; k++) {
        start = second();
        for (i = 0; i < N; i++) {
            avg0 = average_of_3 (a, b, avg0);
            avg1 = average_of_3 (a, b, avg1);
            avg2 = average_of_3 (a, b, avg2);
            avg3 = average_of_3 (a, b, avg3);
            avg4 = average_of_3 (a, b, avg4);
            avg5 = average_of_3 (a, b, avg5);
            avg6 = average_of_3 (a, b, avg6);
            avg7 = average_of_3 (a, b, avg7);
            avg8 = average_of_3 (a, b, avg8);
            avg9 = average_of_3 (a, b, avg9);
            avg10 = average_of_3 (a, b, avg10);
            avg11 = average_of_3 (a, b, avg11);
            avg12 = average_of_3 (a, b, avg12);
            avg13 = average_of_3 (a, b, avg13);
            avg14 = average_of_3 (a, b, avg14);
            avg15 = average_of_3 (a, b, avg15);
            b = (b + avg0) ^ a;
            a = (a ^ b) + avg0;
        }
        stop = second();
        elapsed = fmin (stop - start, elapsed);
    }
    printf ("a=%016llx b=%016llx avg=%016llx",
            (unsigned long long)a, (unsigned long long)b,
            (unsigned long long)(avg0 + avg1 + avg2 + avg3 + avg4 + avg5 +
                                 avg6 + avg7 + avg8 + avg9 + avg10 + avg11 +
                                 avg12 + avg13 + avg14 + avg15));
    printf ("\rthroughput: each average_of_3() took %.6e seconds\n",
            elapsed / 16 / N);
    return EXIT_SUCCESS;
}
#endif // BENCHMARK