我很难理解 cuBLAS 中数组(及其维度)是如何组织的。我做了下面的测试,但无法解释它的输出。谢谢帮助!
#include <stdio.h>
#include <stdlib.h>
#include <cublas.h>
#define DIMX 5 /* rows of the test matrix; also the cuBLAS leading dimension */
#define DIMY 5 /* columns of the test matrix */
#define ROW 2  /* rows of the sub-matrix copied by cublasSetMatrix */
#define COL 3  /* columns of the sub-matrix copied by cublasSetMatrix */
typedef int TYPE;

/*
 * Print a DIMX x DIMY matrix stored in row-major order: printed row i is
 * v[i*DIMY] .. v[i*DIMY + DIMY - 1].
 *
 * cuBLAS stores matrices column-major (element (r, c) at offset r + c*ld),
 * so when the same buffer is handed to cuBLAS with ld = DIMX, the grid
 * printed here is the transpose of the matrix as cuBLAS sees it.
 *
 * v: DIMX*DIMY elements; only read.
 */
void print_matrix(const TYPE *v)
{
    for (int i = 0; i < DIMX; i++) {
        for (int j = 0; j < DIMY; j++) {
            /* "%5d" assumes TYPE is int; update it if TYPE changes. */
            printf("%5d ", v[i * DIMY + j]);
        }
        printf("\n");
    }
}
int main()
{
printf("Hello world!\n");
int i;
//Initialize the array
TYPE v[DIMX*DIMY];
for (i=0; i<DIMX*DIMY; i++) v[i]=i+1;
printf("Before:\n");
print_matrix(v);
//Cublas part
cublasInit();
int *cv;
cublasAlloc(DIMX*DIMY,sizeof(TYPE),(void**)&cv);
cublasSetMatrix(ROW,COL,sizeof(TYPE),v,DIMX,cv,DIMY);
//cublasGetVector(DIMX*DIMY,sizeof(TYPE),cv,1,v,1);
cublasGetVector(DIMX*DIMY,sizeof(TYPE),cv,DIMX,v,DIMX);
cublasFree(cv);
cublasShutdown();
printf("After:\n");
print_matrix(v);
return 0;
}
输出:
Hello world!
Before:
    1     2     3     4     5
    6     7     8     9    10
   11    12    13    14    15
   16    17    18    19    20
   21    22    23    24    25
After:
    1     2     3     4     5
    6     7     8     9    10
   11    12    13    14    15
   16    17    18    19    20
   21    22    23    24    25