我写了一个 struct 和一些用于包装“CUBLAS 矩阵对象”的函数。
struct 的定义是:
#include <stdio.h>
#include <stdlib.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
/* Shorthand for the unsigned type used for matrix dimensions and status codes. */
#define uint unsigned int
/*
 * Descriptor for a single-precision matrix whose elements are kept in a
 * device buffer. The descriptor itself lives in host memory.
 */
typedef struct {
uint rows; /* number of rows; also passed to cuBLAS as the leading dimension */
uint cols; /* number of columns */
float* devPtrvals; /* buffer of rows * cols floats obtained from cudaMalloc in matrix_alloc */
} matrix;
alloc 函数创建矩阵结构:
/*
 * Creates a matrix descriptor on the host together with a device buffer of
 * rows * cols floats.
 *
 * rows, cols: dimensions of the matrix.
 *
 * Returns the new descriptor, or NULL when either the host or the device
 * allocation fails (a message is written to stderr). The device buffer is
 * left uninitialized. Release the result with matrix_free().
 */
matrix* matrix_alloc(uint rows, uint cols)
{
    /* Explicit cast so the file also builds when compiled as C++ (nvcc). */
    matrix* w = (matrix*)malloc(sizeof(matrix));
    if (w == NULL) {
        fprintf(stderr, "host memory allocation failed\n");
        return NULL;
    }
    w->rows = rows;
    w->cols = cols;
    w->devPtrvals = NULL;

    /* sizeof() is a size_t, so the byte count is computed in size_t. */
    cudaError_t cudaStat =
        cudaMalloc((void**)&w->devPtrvals, sizeof(float) * rows * cols);
    if (cudaStat != cudaSuccess) {
        fprintf(stderr, "device memory allocation failed\n");
        free(w); /* do not leak the host descriptor on failure */
        return NULL;
    }
    return w;
}
释放(free)函数:
/*
 * Destroys a matrix created by matrix_alloc(): releases the device buffer
 * and then the host descriptor.
 *
 * w: matrix to destroy. NULL is accepted and treated as a no-op.
 *
 * Returns 1 on success (including w == NULL) and 0 when cudaFree reports an
 * error. The host descriptor is freed in both cases, so w must not be used
 * after this call.
 */
uint matrix_free(matrix* w)
{
    if (w == NULL) {
        return 1;
    }

    cudaError_t cudaStat = cudaFree(w->devPtrvals);
    free(w);

    if (cudaStat != cudaSuccess) {
        fprintf(stderr, "device memory release failed\n");
        return 0;
    }
    return 1;
}
从浮点数组中设置矩阵值的函数:
/*
 * Uploads host values into the device buffer of a matrix.
 *
 * w:    destination matrix (from matrix_alloc()).
 * vals: host array holding w->rows * w->cols floats. The array is handed to
 *       cublasSetMatrix with a leading dimension of w->rows, so cuBLAS reads
 *       it column by column (column-major), not row by row.
 *
 * Returns 1 on success and 0 when an argument is NULL or the upload fails
 * (a message is written to stderr). An empty matrix (zero rows or columns)
 * has nothing to upload and is reported as success.
 */
uint matrix_set_vals(matrix* w, float* vals)
{
    if (w == NULL || vals == NULL) {
        fprintf(stderr, "data upload failed: NULL argument\n");
        return 0;
    }

    /* cuBLAS rejects a leading dimension of 0, so skip the empty case. */
    if (w->rows == 0 || w->cols == 0) {
        return 1;
    }

    cublasStatus_t stat = cublasSetMatrix((int)w->rows, (int)w->cols,
                                          sizeof(float),
                                          vals, (int)w->rows,
                                          w->devPtrvals, (int)w->rows);
    if (stat != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "data upload failed\n");
        return 0;
    }
    return 1;
}
我在编写一个涵盖矩阵转置的通用点积函数时遇到了问题。这是我写的:
/*
 * Computes alpha * op(v) * op(w) into a newly allocated matrix using
 * cublasSgemm, where op() is selected by transA / transB.
 *
 * handle:  cuBLAS handle. alpha and beta are passed by host address, which
 *          matches the default (host) pointer mode of a handle.
 * transA:  CUBLAS_OP_N, CUBLAS_OP_T or CUBLAS_OP_C, applied to v.
 * transB:  CUBLAS_OP_N, CUBLAS_OP_T or CUBLAS_OP_C, applied to w.
 * alpha:   scale factor for the product.
 * v, w:    operands; their buffers are read as column-major with a leading
 *          dimension equal to their stored row count, independent of the
 *          transpose flags.
 * beta:    scale factor applied to the previous contents of the output.
 *          The output is freshly allocated and zero-filled here, so this
 *          term contributes beta * 0.
 *
 * Returns the result matrix (release it with matrix_free()), or NULL when an
 * argument is NULL, the inner dimensions of op(v) and op(w) differ, an
 * allocation fails, or cuBLAS reports an error. A message is written to
 * stderr in each failure case.
 */
matrix* matrix_dot(cublasHandle_t handle, char transA, char transB,
                   float alpha, matrix* v, matrix* w, float beta)
{
    if (v == NULL || w == NULL) {
        fprintf(stderr, "matrix_dot: NULL operand\n");
        return NULL;
    }

    /* cublasSgemm takes cublasOperation_t; convert once, explicitly. */
    const cublasOperation_t opA = (cublasOperation_t)transA;
    const cublasOperation_t opB = (cublasOperation_t)transB;

    /* Shapes after applying op(): op(v) is m x kA, op(w) is kB x n. */
    const uint m  = (opA == CUBLAS_OP_N) ? v->rows : v->cols;
    const uint kA = (opA == CUBLAS_OP_N) ? v->cols : v->rows;
    const uint kB = (opB == CUBLAS_OP_N) ? w->rows : w->cols;
    const uint n  = (opB == CUBLAS_OP_N) ? w->cols : w->rows;

    if (kA != kB) {
        fprintf(stderr,
                "matrix_dot: inner dimensions do not match (%u vs %u)\n",
                kA, kB);
        return NULL;
    }

    matrix* x = matrix_alloc(m, n);
    if (x == NULL) {
        return NULL;
    }

    /* An empty result has no elements to compute. */
    if (m == 0 || n == 0) {
        return x;
    }

    /*
     * Zero the output on the handle's stream so that the beta * C term never
     * reads uninitialized device memory and is ordered before the GEMM.
     */
    cudaStream_t stream = 0;
    if (cublasGetStream(handle, &stream) != CUBLAS_STATUS_SUCCESS ||
        cudaMemsetAsync(x->devPtrvals, 0, sizeof(float) * m * n, stream)
            != cudaSuccess) {
        fprintf(stderr, "matrix_dot: could not initialize result\n");
        matrix_free(x);
        return NULL;
    }

    /* With an empty inner dimension the product is the zero matrix. */
    if (kA == 0) {
        return x;
    }

    cublasStatus_t stat = cublasSgemm(handle, opA, opB,
                                      (int)m, (int)n, (int)kA,
                                      &alpha,
                                      v->devPtrvals, (int)v->rows,
                                      w->devPtrvals, (int)w->rows,
                                      &beta,
                                      x->devPtrvals, (int)x->rows);
    if (stat != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "matrix multiplication failed\n");
        matrix_free(x); /* do not leak the result buffer on failure */
        return NULL;
    }
    return x;
}
例子:
我想要一个矩阵A:
1 2 3
4 5 6
7 8 9
10 11 12
这意味着:
/*
 * cuBLAS reads the host buffer column by column, so the 4x3 matrix A is
 * listed one column at a time. An array (not a pointer) is required for a
 * brace-enclosed list of values.
 */
float a[] = {1, 4, 7, 10,   /* column 0 */
             2, 5, 8, 11,   /* column 1 */
             3, 6, 9, 12};  /* column 2 */
matrix* A = matrix_alloc(4, 3);
matrix_set_vals(A, a);
并将其与 B 的转置相乘,其中 B 为:
1 2 3
4 5 6
即:
/*
 * cuBLAS reads the host buffer column by column, so the 2x3 matrix B is
 * listed one column at a time. An array (not a pointer) is required for a
 * brace-enclosed list of values.
 */
float b[] = {1, 4,   /* column 0 */
             2, 5,   /* column 1 */
             3, 6};  /* column 2 */
matrix* B = matrix_alloc(2, 3);
matrix_set_vals(B, b);
A*B^T=C 的结果:
14 32
32 77
50 122
68 167
我这样调用点积函数:
/* C = 1 * A * B^T; beta is 0, so no previous contents are blended in. */
matrix* C = matrix_dot(handle,
                       CUBLAS_OP_N, CUBLAS_OP_T,
                       1.0f, A, B, 0.0f);
调用此函数时,我得到如下错误:** On entry to SGEMM parameter number 10 had an illegal value
我究竟做错了什么?