2

对于 C 或 C++ 来说,这是一个有点棘手的问题。我在 Ubuntu 12.04.2 下运行 GCC 4.6.3。

我有一个p具有以下形式的三维张量的内存访问索引:

p = (i<<(2*N)) + (j<<N) + k

这里0 <= i,j,k < (1<<N)还有N一些正整数。

i>>S, j>>S, k>>S现在我想为with计算一个“缩小的”内存访问索引0 < S < N,它是:

q = ((i>>S)<<(2*(N-S))) + ((j>>S)<<(N-S)) + (k>>S)

最快的计算方法是什么q(事先p不知道i,j,k)?我们可以假设0 < N <= 10(即p是一个 32 位整数)。我会对N=8(即i,j,k8 位整数)的快速方法特别感兴趣。N并且S都是编译时常量。

N=8和的一个例子S=4

unsigned int p = 240407; // this is (3<<16) + (171<<8) + 23;
unsigned int q = 161; // this is (0<<8) + (10<<4) + 1
4

3 回答 3

1

直截了当的方式,8 个操作(其他是对常量的操作):

M = (1<<(N-S)) - 1;                     // A mask with S lowest bits.
q = (  ((p & (M<<(2*N+S))) >> (3*S))    // Mask 'i', shift to new position.
     + ((p & (M<<(  N+S))) >> (2*S))    // Likewise for 'j'.
     + ((p & (M<<     S))  >>    S));   // Likewise for 'k'.

看起来很复杂,但实际上并非如此,只是不容易(至少对我而言)让所有常量都正确。

为了创建具有较少操作的公式,我们观察到将数字U向左移动一位与乘以 相同1<<U。因此,由于乘法分布性,乘以等于((1<<U1) + (1<<U2) + ...)左移U1, U2, ... 并将所有内容相加。

因此,我们可以尝试屏蔽和的所需部分i,通过一次乘法将它们全部“移动”到相对于彼此的正确位置,然后将结果向右移动到最终目的地。这给了我们三个操作来计算。jkqp

不幸的是,有一些限制,特别是对于我们试图同时获得所有三个的情况。当我们将数字相加时(间接地,通过将​​多个乘数相加),我们必须确保只能在一个数字中设置位,否则我们会得到错误的结果。如果我们尝试一次添加(间接)三个正确移位的数字,我们有:

iiiii...........jjjjj...........kkkkk.......
 N-S      S      N-S      S      N-S
.....jjjjj...........kkkkk................
 N-S  N-S      S      N-S
..........kkkkk...............
 N-S  N-S  N-S

请注意,第二个和第三个数字的左侧是 和 的位ij但我们忽略它们。为此,我们假设乘法在 x86 上工作:将两种类型相乘T得到多个 type T,只有实际结果的最低位(如果没有溢出,则等于结果)

因此,为了确保k第三个数字中的位不与第一个数字中的位重叠j,我们需要它,3*(N-S) <= N即限制我们(移位后每个组件只有一个或两个位;不知道您是否曾经使用过精度低)。S >= 2*N/3N = 8S >= 6

但是,如果S >= 2*N/3,我们只能使用 3 个操作:

// Constant multiplier to perform three shifts at once.
F = (1<<(32-3*N)) + (1<<(32-3*N+S)) + (1<<(32-3*N+2*S));
// Mask, shift/combine with multipler, right shift to destination.
q = (((p & ((M<<(2*N+S)) + (M<<(N+S)) + (M<<S))) * F)
     >> (32-3*(N-S)));

如果 for 的约束S太严格(可能是这样),我们可以结合第一个和第二个公式:用第二种方法计算ik,然后j从第一个公式添加。在这里,我们需要以下数字中的位不重叠:

iiiii...............kkkkk.......
 N-S   S   N-S   S   N-S
..........kkkkk...............
 N-S  N-S  N-S

3*(N-S) <= 2*N,它给出了S >= N / 3,或者,N = 8更不严格S >= 3。公式如下:

// Constant multiplier to perform two shifts at once.
F = (1<<(32-3*N)) + (1<<(32-3*N+2*S));
// Mask, shift/combine with multipler, right shift to destination
// and then add 'j' from the straightforward formula.
q = ((((p & ((M<<(2*N+S)) + (M<<S))) * F) >> (32-3*(N-S)))
     + ((p & (M<<(N+S))) >> (2*S)));

此公式也适用于您的示例 where S = 4

这是否比直接方法更快取决于架构。另外,我不知道 C++ 是否保证假设的乘法溢出行为。最后,您需要确保值是无符号的并且正好是32 位,这样公式才能正常工作。

于 2013-08-04T19:07:48.847 回答
0

如果您不关心兼容性,对于 N = 8,您可以像这样得到 i、j、k:

 int p = .... 
 unsigned char *bytes = (char *)&p;

Now kis bytes[0], jisbytes[1]iis bytes[2](我在我的机器上发现了 little endian)。但我认为更好的方法是……。像那样(我们有 N_MASK = 2^N - 1)

 int q;
 q = ( p & N_MASK ) >> S;
 p >>= N;
 q |= ( ( p & N_MASK ) >> S ) << S;
 p >>= N;
 q |= ( ( p & N_MASK ) >> S ) << 2*S;
于 2013-07-31T17:43:39.143 回答
0

它符合您的要求吗?

#include <cstdint>
#include <iostream>

uint32_t to_q_from_p(uint32_t p, uint32_t N, uint32_t S)
{
   uint32_t mask = ~(~0 << N);
   uint32_t k = p &mask;
   uint32_t j = (p >> N)& mask;
   uint32_t i = (p >> 2*N)&mask;
   return ((i>>S)<<(2*(N-S))) + ((j>>S)<<(N-S)) + (k>>S);;
}

int main()
{
   uint32_t p = 240407;

   uint32_t q = to_q_from_p(p, 8, 4);

   std::cout << q << '\n';

}

如果你假设 N 总是 8 并且整数是小端的,那么它可以是

uint32_t to_q_from_p(uint32_t p, uint32_t S)
{
   auto ptr = reinterpret_cast<uint8_t*>(&p);
   return ((ptr[2]>>S)<<(2*(8-S))) + ((ptr[1]>>S)<<(8-S)) + (ptr[0]>>S);
}
于 2013-07-31T17:39:14.080 回答