2

我是 C 的初学者。我试图编写一些代码来使用转置执行矩阵乘法。有什么办法可以在执行时间方面改进代码?

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <assert.h>
#include <time.h>

int main()
{   

  int a[3][3] = {{1,0, 1}, {2, 2, 4},{1, 2, 3}};

        int b[3][3] ={ { 2, 3, 1}, { 6, 6, 2 }, { 9, 9, 0 } };
        int result[3][3];
        double tmp;
        int i,j,k;
        for (i=0; i<3; i++) //i = col
          {
            for (k=0; k<3; k++)
            {
              tmp = a[i][k];
              for (j=0; j<3; j++) //j = row
              {
                result[i][j] += tmp * b[k][j];
                printf("%d\t",result[i][j]);
              }
            }
          }
}
4

4 回答 4

3

如果你的矩阵是int,你真的不应该使用 adouble作为临时的。毫无目的地将整数转换为浮点数然后再转换回来是非常浪费的,它可能会花费很多。

于 2012-10-09T07:48:56.097 回答
2

由于多种原因,您的矩阵乘法实现是错误的。矩阵乘法是通过计算第一个矩阵的每一行与第二个矩阵的每一列的内积来执行的,这在您的实现中基本上被遗漏了。您正在使用指向a[i][k] 的temp变量,该变量在最内层循环中保持不变。第一个矩阵的行索引和第二个矩阵的列索引(反之亦然)必须在实际乘法步骤中更新。此外,结果会逐渐添加到第三个矩阵中,每个元素都必须在 C 等语言中用 0 初始化,以避免垃圾值的问题。

于 2012-10-09T10:34:01.000 回答
1

要尝试的一件非常不直观的事情是对源代码进行反优化并消除显式tmp

for (i=0; i<3; i++)
    for (k=0; k<3; k++)
        for (j=0; j<3; j++) //j = row
        {
            result[i][j] += a[i][k] * b[k][j];
        }

这在某种程度上解开了编译器的束缚,并允许它自己找到常见的不变子表达式。它将它们移到循环之外,也许还使用更快的范例(寄存器而不是堆栈位置)来保存它。

根据目标 CPU,启用速度优化的精明编译器可能能够通过分配独立寄存器和展开内部循环来并行 CPU 的流水线。当然,这一切都取决于您指示编译器进行优化(使用适当的编译器选项)。

于 2012-10-09T07:58:08.963 回答
0

您是否尝试过使用结构?

这是一个 4 x 4 矩阵乘法,例如。将其打开 3 x 3 很容易。

typedef struct s_matrix

{

double    a;
double    b;
double    c;
double    d;

double    e;
double    f;
double    g;
double    h;

double    i;
double    j;
double    k;
double    l;

double    m;
double    n;
double    o;
double    p;

} t_matrix;

t_matrix *m4_mul_m4(t_matrix *b, t_matrix *a)

{ t_matrix *m;

m = malloc(sizeof(t_matrix));
m->a = a->a * b->a + a->b * b->e + a->c * b->i + a->d * b->m;
m->b = a->a * b->b + a->b * b->f + a->c * b->j + a->d * b->n;
m->c = a->a * b->c + a->b * b->g + a->c * b->k + a->d * b->o;
m->d = a->a * b->d + a->b * b->h + a->c * b->l + a->d * b->p;
m->e = a->e * b->a + a->f * b->e + a->g * b->i + a->h * b->m;
m->f = a->e * b->b + a->f * b->f + a->g * b->j + a->h * b->n;
m->g = a->e * b->c + a->f * b->g + a->g * b->k + a->h * b->o;
m->h = a->e * b->d + a->f * b->h + a->g * b->l + a->h * b->p;
m->i = a->i * b->a + a->j * b->e + a->k * b->i + a->l * b->m;
m->j = a->i * b->b + a->j * b->f + a->k * b->j + a->l * b->n;
m->k = a->i * b->c + a->j * b->g + a->k * b->k + a->l * b->o;
m->l = a->i * b->d + a->j * b->h + a->k * b->l + a->l * b->p;
m->m = a->m * b->a + a->n * b->e + a->o * b->i + a->p * b->m;
m->n = a->m * b->b + a->n * b->f + a->o * b->j + a->p * b->n;
m->o = a->m * b->c + a->n * b->g + a->o * b->k + a->p * b->o;
m->p = a->m * b->d + a->n * b->h + a->o * b->l + a->p * b->p;
return (m);

}

于 2014-12-20T20:52:31.013 回答