21

对于名为“高性能计算”的课程的作业,我需要优化以下代码片段:

int foobar(int a, int b, int N)
{
    int i, j, k, x, y;
    x = 0;
    y = 0;
    k = 256;
    for (i = 0; i <= N; i++) {
        for (j = i + 1; j <= N; j++) {
            x = x + 4*(2*i+j)*(i+2*k);
            if (i > j){
               y = y + 8*(i-j);
            }else{
               y = y + 8*(j-i);
            }
        }
    }
    return x;
}

使用一些建议,我设法优化了代码(或者至少我是这么认为的),例如:

  1. 不断传播
  2. 代数简化
  3. 复制传播
  4. 公共子表达式消除
  5. 死代码消除
  6. 循环不变量去除
  7. 按位移位而不是乘法,因为它们更便宜。

这是我的代码:

int foobar(int a, int b, int N) {

    int i, j, x, y, t;
    x = 0;
    y = 0;
    for (i = 0; i <= N; i++) {
        t = i + 512;
        for (j = i + 1; j <= N; j++) {
            x = x + ((i<<3) + (j<<2))*t;
        }
    }
    return x;
}

根据我的导师的说法,一个经过良好优化的代码指令在汇编语言级别应该有更少或更少成本的指令。因此必须运行,指令比原始代码更短的时间,即计算是用:

执行时间 = 指令数 * 每条指令的周期数

当我使用命令生成汇编代码时:gcc -o code_opt.s -S foobar.c

尽管进行了一些优化,但生成的代码比原始代码行数多,运行时间更低,但没有原始代码那么多。我究竟做错了什么?

不要粘贴汇编代码,因为两者都非常广泛。所以我在 main 中调用函数“foobar”,我正在使用 linux 中的 time 命令测量执行时间

int main () {
    int a,b,N;

    scanf ("%d %d %d",&a,&b,&N);
    printf ("%d\n",foobar (a,b,N));
    return 0;
}
4

7 回答 7

23

最初:

for (i = 0; i <= N; i++) {
    for (j = i + 1; j <= N; j++) {
        x = x + 4*(2*i+j)*(i+2*k);
        if (i > j){
           y = y + 8*(i-j);
        }else{
           y = y + 8*(j-i);
        }
    }
}

删除y计算:

for (i = 0; i <= N; i++) {
    for (j = i + 1; j <= N; j++) {
        x = x + 4*(2*i+j)*(i+2*k);
    }
}

拆分i, j, k:

for (i = 0; i <= N; i++) {
    for (j = i + 1; j <= N; j++) {
        x = x + 8*i*i + 16*i*k ;                // multiple of  1  (no j)
        x = x + (4*i + 8*k)*j ;                 // multiple of  j
    }
}

将它们移到外部(并删除运行N-i时间的循环):

for (i = 0; i <= N; i++) {
    x = x + (8*i*i + 16*i*k) * (N-i) ;
    x = x + (4*i + 8*k) * ((N*N+N)/2 - (i*i+i)/2) ;
}

重写:

for (i = 0; i <= N; i++) {
    x = x +         ( 8*k*(N*N+N)/2 ) ;
    x = x +   i   * ( 16*k*N + 4*(N*N+N)/2 + 8*k*(-1/2) ) ;
    x = x +  i*i  * ( 8*N + 16*k*(-1) + 4*(-1/2) + 8*k*(-1/2) );
    x = x + i*i*i * ( 8*(-1) + 4*(-1/2) ) ;
}

重写 - 重新计算:

for (i = 0; i <= N; i++) {
    x = x + 4*k*(N*N+N) ;                            // multiple of 1
    x = x +   i   * ( 16*k*N + 2*(N*N+N) - 4*k ) ;   // multiple of i
    x = x +  i*i  * ( 8*N - 20*k - 2 ) ;             // multiple of i^2
    x = x + i*i*i * ( -10 ) ;                        // multiple of i^3
}

另一个转移到外部(并删除 i 循环):

x = x + ( 4*k*(N*N+N) )              * (N+1) ;
x = x + ( 16*k*N + 2*(N*N+N) - 4*k ) * ((N*(N+1))/2) ;
x = x + ( 8*N - 20*k - 2 )           * ((N*(N+1)*(2*N+1))/6);
x = x + (-10)                        * ((N*N*(N+1)*(N+1))/4) ;

上述两个循环删除都使用求和公式:

Sum(1, i = 0..n) = n+1
Sum(i 1 , i = 0..n) = n(n + 1)/2
Sum(i 2 , i = 0..n) = n (n + 1)(2n + 1)/6
Sum(i 3 , i = 0..n) = n 2 (n + 1) 2 /4

于 2012-11-25T22:46:37.757 回答
22

y不影响代码的最终结果 - 删除:

int foobar(int a, int b, int N)
{
    int i, j, k, x, y;
    x = 0;
    //y = 0;
    k = 256;
    for (i = 0; i <= N; i++) {
        for (j = i + 1; j <= N; j++) {
            x = x + 4*(2*i+j)*(i+2*k);
            //if (i > j){
            //   y = y + 8*(i-j);
            //}else{
            //   y = y + 8*(j-i);
            //}
        }
    }
    return x;
}

k只是一个常数:

int foobar(int a, int b, int N)
{
    int i, j, x;
    x = 0;
    for (i = 0; i <= N; i++) {
        for (j = i + 1; j <= N; j++) {
            x = x + 4*(2*i+j)*(i+2*256);
        }
    }
    return x;
}

内部表达式可以转换为:x += 8*i*i + 4096*i + 4*i*j + 2048*j. 使用数学将它们全部推到外循环:x += 8*i*i*(N-i) + 4096*i*(N-i) + 2*i*(N-i)*(N+i+1) + 1024*(N-i)*(N+i+1).

您可以扩展上面的表达式,并应用平方和和立方和公式来获得一个封闭形式的表达式,它应该比双重嵌套循环运行得更快。我把它作为练习留给你。这样一来,ij将被删除。

a如果可能,也b应该将其删除 - 因为ab作为参数提供,但从未在您的代码中使用。

平方和和立方和公式:

  • Sum(x 2 , x = 1..n) = n(n + 1)(2n + 1)/6
  • 总和(x 3 , x = 1..n) = n 2 (n + 1) 2 /4
于 2012-11-25T22:15:53.953 回答
20

此函数等价于以下公式,其中仅包含4 次整数乘法1 次整数除法

x = N * (N + 1) * (N * (7 * N + 8187) - 2050) / 6;

为此,我只需将嵌套循环计算的总和输入Wolfram Alpha

sum (sum (8*i*i+4096*i+4*i*j+2048*j), j=i+1..N), i=0..N

是解决方案的直接链接。编码前三思。有时你的大脑可以比任何编译器更好地优化代码。

于 2012-11-25T22:50:16.357 回答
5

简要浏览第一个例程,您注意到的第一件事是涉及“y”的表达式完全未使用并且可以删除(就像您所做的那样)。这进一步允许消除 if/else(就像您所做的那样)。

剩下的是两个for循环和凌乱的表达。j下一步是分解出该表达式中不依赖的部分。您删除了一个这样的表达式,但(i<<3)(ie, i * 8) 仍保留在内循环中,并且可以删除。

Pascal 的回答提醒我,您可以使用循环步幅优化。首先移出(i<<3) * t内部循环(调用它i1),然后在初始化循环时计算一个j1等于的值(i<<2) * t。在每次迭代中递增j14 * t这是一个预先计算的常数)。用 替换你的内心表达x = x + i1 + j1;

有人怀疑可能有某种方法可以将两个循环组合成一个循环,但我并没有立即看到它。

于 2012-11-25T22:05:14.417 回答
2

其他一些我可以看到的东西。你不需要y,所以你可以删除它的声明和初始化。

此外,传入的值ab并没有实际使用,因此您可以将它们用作局部变量而不是xand t

此外,i您可以注意不是每次都添加到 512 ,而是t从 512 开始,每次迭代递增 1。

int foobar(int a, int b, int N) {
    int i, j;
    a = 0;
    b = 512;
    for (i = 0; i <= N; i++, b++) {
        for (j = i + 1; j <= N; j++) {
            a = a + ((i<<3) + (j<<2))*b;
        }
    }
    return a;
}

一旦达到这一点,您还可以观察到,除了初始化j,i并且j仅在单个多个中使用 each - i<<3and j<<2。我们可以直接在循环逻辑中对此进行编码,因此:

int foobar(int a, int b, int N) {
    int i, j, iLimit, jLimit;
    a = 0;
    b = 512;
    iLimit = N << 3;
    jLimit = N << 2;
    for (i = 0; i <= iLimit; i+=8) {
        for (j = i >> 1 + 4; j <= jLimit; j+=4) {
            a = a + (i + j)*b;
        }
        b++;
    }
    return a;
}
于 2012-11-25T22:03:57.520 回答
2

好的......所以这是我的解决方案,以及解释我做了什么以及如何做的内联评论。

int foobar(int N)
{ // We eliminate unused arguments 
    int x = 0, i = 0, i2 = 0, j, k, z;

    // We only iterate up to N on the outer loop, since the
    // last iteration doesn't do anything useful. Also we keep
    // track of '2*i' (which is used throughout the code) by a 
    // second variable 'i2' which we increment by two in every
    // iteration, essentially converting multiplication into addition.
    while(i < N) 
    {           
        // We hoist the calculation '4 * (i+2*k)' out of the loop
        // since k is a literal constant and 'i' is a constant during
        // the inner loop. We could convert the multiplication by 2
        // into a left shift, but hey, let's not go *crazy*! 
        //
        //  (4 * (i+2*k))         <=>
        //  (4 * i) + (4 * 2 * k) <=>
        //  (2 * i2) + (8 * k)    <=>
        //  (2 * i2) + (8 * 512)  <=>
        //  (2 * i2) + 2048

        k = (2 * i2) + 2048;

        // We have now converted the expression:
        //      x = x + 4*(2*i+j)*(i+2*k);
        //
        // into the expression:
        //      x = x + (i2 + j) * k;
        //
        // Counterintuively we now *expand* the formula into:
        //      x = x + (i2 * k) + (j * k);
        //
        // Now observe that (i2 * k) is a constant inside the inner
        // loop which we can calculate only once here. Also observe
        // that is simply added into x a total (N - i) times, so 
        // we take advantange of the abelian nature of addition
        // to hoist it completely out of the loop

        x = x + (i2 * k) * (N - i);

        // Observe that inside this loop we calculate (j * k) repeatedly, 
        // and that j is just an increasing counter. So now instead of
        // doing numerous multiplications, let's break the operation into
        // two parts: a multiplication, which we hoist out of the inner 
        // loop and additions which we continue performing in the inner 
        // loop.

        z = i * k;

        for (j = i + 1; j <= N; j++) 
        {
            z = z + k;          
            x = x + z;      
        }

        i++;
        i2 += 2;
    }   

    return x;
}

没有任何解释的代码归结为:

int foobar(int N)
{
    int x = 0, i = 0, i2 = 0, j, k, z;

    while(i < N) 
    {                   
        k = (2 * i2) + 2048;

        x = x + (i2 * k) * (N - i);

        z = i * k;

        for (j = i + 1; j <= N; j++) 
        {
            z = z + k;          
            x = x + z;      
        }

        i++;
        i2 += 2;
    }   

    return x;
}

我希望这有帮助。

于 2012-11-26T00:43:01.843 回答
0

int foobar(int N) //避免未使用的传递参数

{

int i, j, x=0;   //Remove unuseful variable, operation so save stack and Machine cycle

for (i = N; i--; )               //Don't check unnecessary comparison condition 

   for (j = N+1; --j>i; )

     x += (((i<<1)+j)*(i+512)<<2);  //Save Machine cycle ,Use shift instead of Multiply

return x;

}

于 2013-04-04T07:06:54.957 回答