2

我正在使用 mex 桥对 Matlab 中的稀疏矩阵执行一些操作。为此,我需要将输入矩阵转换为 CSR(压缩行存储)格式,因为 Matlab 将稀疏矩阵存储在 CSC(压缩列存储)中。

我能够得到值数组和 column_indices 数组。但是,我正在努力获取 CSR 格式的 row_pointer 数组。是否有任何 C 库可以帮助从 CSC 转换为 CSR?

此外,在编写 CUDA 内核时,使用 CSR 格式进行稀疏操作是否有效,或者我应该只使用以下数组:- 行索引、列索引和值?

哪个可以让我更好地控制数据,最大限度地减少自定义内核中的 for 循环数量?

4

3 回答 3

4

压缩行存储类似于压缩列存储,只是转置了。所以最简单的方法是在将矩阵传递给 MEX 文件之前使用 MATLAB 对其进行转置。然后,使用函数

Ap = mxGetJc(spA);
Ai = mxGetIr(spA);
Ax = mxGetPr(spA);

获取内部指针并将它们视为行存储。Ap 是行指针,Ai 是非零条目的列索引,Ax 是非零值。请注意,对于对称矩阵,您根本不需要做任何事情!CSC 和 CSR 是一样的。

使用哪种格式很大程度上取决于您以后要对矩阵做什么。例如,查看稀疏矩阵向量乘法的矩阵格式。那是经典论文之一,从那时起研究已经转移,因此您可以进一步环顾四周。

于 2012-09-05T09:08:23.257 回答
1

我最终使用 CUSP 库将 CSC 格式从 Matlab 转换为 CSR,如下所示。

A从 matlab获取矩阵后,我得到了它的row,colvalues向量,我将它们分别复制到thrust::host_vector为它们每个创建的中。

之后,我创建了两个cusp::array1d类型IndicesValues如下所示。

    typedef typename cusp::array1d<int,cusp::host_memory>Indices;   
    typedef typename cusp::array1d<float,cusp::host_memory>Values;
    Indices row_indices(rows.begin(),rows.end());
    Indices col_indices(cols.begin(),cols.end());
    Values  Vals(Val.begin(),Val.end());

在哪里rows,是我从 Matlab 得到的colsValthrust::host_vector

之后,我创建了一个cusp::coo_matrix_view如下所示的。

typedef cusp::coo_matrix_view<Indices,Indices,Values>HostView;
HostView Ah(m,n,NNZa,row_indices,col_indices,Vals);

其中m和是我从稀疏矩阵函数中n得到的参数。NNZamex

我将此视图矩阵复制到cusp::csr_matrix设备内存中,并设置了正确的尺寸,如下所示。

    cusp::csr_matrix<int,float,cusp::device_memory>CSR(m,n,NNZa);
    CSR = Ah;   

之后,我只是将这个 CSR 矩阵的三个单独的内容数组复制回主机,使用thrust::raw_pointer_castwhere 已经mxCalloc编辑了具有适当维度的数组,如下所示。

 cudaMemcpy(Acol,thrust::raw_pointer_cast(&CSR.column_indices[0]),sizeof(int)*(NNZa),cudaMemcpyDeviceToHost);
 cudaMemcpy(Aptr,thrust::raw_pointer_cast(&CSR.row_offsets[0]),sizeof(int)*(n+1),cudaMemcpyDeviceToHost);
 cudaMemcpy(Aval,thrust::raw_pointer_cast(&CSR.values[0]),sizeof(float)*(NNZa),cudaMemcpyDeviceToHost);

希望这对使用的任何人都有CUSPMatlab

于 2012-09-13T04:19:29.793 回答
0

你可以这样做:

n = size(M,1);
nz_num = nnz(M);
[col,rowi,vals] = find(M');
row = zeros(n+1,1);
ll = 1; row(1) = 1;
for l = 2:n
    if rowi(l)~=rowi(l-1)
        ll = ll + 1;
        row(ll) = l;
    end
end
row(n+1) = nz_num+1;`

它对我有用,希望它可以帮助别人!

于 2019-02-12T02:58:14.160 回答