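I wrote a Cython program that calls Intel MKL to do matrix multiplication, with the goal of running it in parallel. It is based on an old SO post about linking to BLAS, and it uses a bunch of Cython techniques I had never seen before. It works, but it is much slower than NumPy (which is also linked against MKL). To speed it up, I switched to the usual typed-memoryview style (the earlier version used `np.ndarray` buffers of type `np.float64_t` for several operations). But now it no longer works with `double[::1]` memoryviews. This is the error it generates: `'type cast': cannot convert from '__Pyx_memviewslice' to 'double *'`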
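Some background on that error. In the generated C code, a typed memoryview is a `__Pyx_memviewslice` struct. The struct holds a data pointer plus shape and stride arrays, so C cannot cast the struct itself to `double *`. BLAS only needs the address of the first element. In Cython you write that address as `&A[i, 0]` (or `&A_row[0]`), not `<double*>A_row`. The NumPy-only sketch below is plain Python, not Cython. It assumes NumPy's `__array_interface__` as a way to read buffer addresses. It checks the pointer arithmetic behind `&A[i, 0]`: the address of row `i` is the base address plus `i * strides[0]` bytes. The helper `element_address` is hypothetical and exists only for illustration.

```python
import numpy as np

def element_address(arr, *index):
    """Byte address of arr[index], computed from the base pointer and strides.

    Hypothetical helper: it mirrors what `&arr[i, j]` yields in Cython.
    """
    base = arr.__array_interface__["data"][0]
    return base + sum(i * s for i, s in zip(index, arr.strides))

A = np.arange(12, dtype=np.float64).reshape(3, 4)  # C-contiguous 3x4 array

# The view A[1, :] starts exactly where &A[1, 0] points.
row_view = A[1, :]
assert row_view.__array_interface__["data"][0] == element_address(A, 1, 0)

# Elements within a row are one double (8 bytes) apart, so incX = 1.
print(A.strides)                    # (32, 8)
print(A.strides[1] // A.itemsize)   # 1
```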

Because that cast fails, the compiler treats the MKL call as if it had only 3 of its 5 arguments: `error C2660: 'cblas_ddot': function does not take 3 arguments`
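Here is what the five arguments of `cblas_ddot(N, X, incX, Y, incY)` mean. The routine reads `N` elements from `X`, stepping `incX` elements each time, does the same for `Y` with `incY`, and returns the sum of the products. The pure-Python model below only makes that contract concrete. It is not MKL, and it ignores the negative-increment case. The function name `ddot_model` is hypothetical, and the `offx`/`offy` element offsets stand in for the pointer arguments.

```python
import numpy as np

def ddot_model(n, x, offx, incx, y, offy, incy):
    """Pure-Python model of cblas_ddot(N, X, incX, Y, incY) for positive increments.

    `x` and `y` are flat float64 buffers; `offx`/`offy` play the role of the
    pointer arguments (an element offset into each buffer).
    """
    total = 0.0
    for k in range(n):
        total += x[offx + k * incx] * y[offy + k * incy]
    return total

buf_a = np.array([1.0, 2.0, 3.0])
buf_b = np.array([10.0, 0.0, 20.0, 0.0, 30.0, 0.0])
# Read every second element of buf_b: 1*10 + 2*20 + 3*30 = 140
print(ddot_model(3, buf_a, 0, 1, buf_b, 0, 2))  # 140.0
```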

Here is the .pyx code:

```
import numpy as np
cimport numpy as np
cimport cython
from cython cimport view
from cython.parallel cimport prange     #this is your OpenMP portion
from openmp cimport omp_get_max_threads #only used for getting the max # of threads on the machine 

cdef extern from "mkl_cblas.h" nogil: #import a function from Intel's MKL library
    double ddot "cblas_ddot"(int N,
                             double *X, 
                             int incX,
                             double *Y, 
                             int incY)

cpdef matmult(double[:,::1] A, double[:,::1] B):
    cdef int Ashape0=A.shape[0], Ashape1=A.shape[1], Bshape0=B.shape[0], Bshape1=B.shape[1], Arowshape0=A[0,:].shape[0] #these are defined here as they aren't allowed in a prange loop

    if Ashape1 != Bshape1:
        raise TypeError('Inner dimensions are not consistent!')

    cdef int i, j
    cdef double[:,::1] out = np.zeros((Ashape0, Bshape1))
    cdef double[::1] A_row = np.zeros(Ashape0)
    cdef double[:] B_col = np.zeros(Bshape1) #no idea why this is not allowed to be [::1]
    cdef int Arowstrides = A_row.strides[0] // sizeof(double)
    cdef int Bcolstrides = B_col.strides[0] // sizeof(double)
    cdef int maxthreads = omp_get_max_threads()

    for i in prange(Ashape0, nogil=True, num_threads=maxthreads, schedule='static'): # to use all cores

        A_row = A[i,:]
        for j in range(Bshape1):
            B_col = B[:,j]
            out[i,j] = ddot(Arowshape0, #call the imported Intel MKL library
                            <double*>A_row, #these memoryview casts trigger the C2440 error above
                            Arowstrides,
                            <double*>B_col,
                            Bcolstrides)

    return np.asarray(out)
```
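Several things in that listing come down to strides. `B[:, j]` is a column of a C-contiguous array, so its elements are `Bshape1` doubles apart. That is why Cython will not accept it as `double[::1]`. It is also why `Bcolstrides` is wrong: it is computed from a freshly allocated contiguous `np.zeros` array, so it comes out as 1 instead of `Bshape1`. Separately, for `A @ B` the inner dimensions to compare are `A.shape[1]` and `B.shape[0]`. The check above compares `Ashape1` with `Bshape1`.

The NumPy-only sketch below checks these facts. It then builds the product from strided dot products on the flat, C-ordered buffers of A and B. It assumes C-contiguous float64 inputs, and `matmult_strided` is a hypothetical name. In Cython, the corresponding MKL call would presumably be `ddot(Ashape1, &A[i, 0], 1, &B[0, j], Bshape1)`, which avoids reassigning `A_row`/`B_col` inside the `prange` loop.

```python
import numpy as np

def matmult_strided(A, B):
    """Compute A @ B with one strided dot product per output element.

    Models the intended call ddot(K, &A[i, 0], 1, &B[0, j], N) on the flat,
    C-ordered buffers of A and B (pure NumPy, no MKL involved).
    """
    A = np.ascontiguousarray(A, dtype=np.float64)
    B = np.ascontiguousarray(B, dtype=np.float64)
    M, K = A.shape
    if K != B.shape[0]:                 # inner dimensions: A.shape[1] vs B.shape[0]
        raise ValueError("Inner dimensions are not consistent!")
    N = B.shape[1]
    a_flat, b_flat = A.ravel(), B.ravel()
    inc_b = B.strides[0] // B.itemsize  # distance from B[k, j] to B[k+1, j]; equals N
    out = np.zeros((M, N))
    for i in range(M):
        for j in range(N):
            # X starts at element i*K (&A[i, 0]) with step 1;
            # Y starts at element j (&B[0, j]) with step inc_b.
            x = a_flat[i * K : i * K + K]
            y = b_flat[j : j + (K - 1) * inc_b + 1 : inc_b]
            out[i, j] = np.dot(x, y)
    return out

# A column of a C-contiguous array is strided, not contiguous.
B = np.arange(6, dtype=np.float64).reshape(2, 3)
col = B[:, 1]
print(col.flags["C_CONTIGUOUS"])        # False
print(col.strides[0] // col.itemsize)   # 3, i.e. B.shape[1]

A = np.arange(6, dtype=np.float64).reshape(3, 2)
print(np.allclose(matmult_strided(A, B), A @ B))  # True
```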

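I'm sure someone on SO can point this out easily. If you see anything else that could be improved, please let me know. This code was hacked together and chopped up, and I don't think the i/j loops are even needed. The cleanest example I found is https://gist.github.com/JonathanRaiman/f2ce5331750da7b2d4e9 ; what I eventually compiled from it is actually much faster (2x) but does not give any results, so I put it in a separate post (here: Directly calling BLAS / LAPACK using the SciPy interface and Cython - and how to add MKL).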
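The remark that the i/j loops may not be needed is right. The whole product is a single level-3 BLAS operation (dgemm), and MKL's dgemm is already multithreaded. A loop of level-1 `ddot` calls therefore mostly adds call overhead and loses dgemm's cache blocking. The sketch below assumes SciPy is installed and uses its Python-level wrapper `scipy.linalg.blas.dgemm`. That wrapper calls whatever BLAS SciPy was built against, which may or may not be MKL. From Cython, the analogous routes would be `cblas_dgemm` from `mkl_cblas.h` or `scipy.linalg.cython_blas`.

```python
import numpy as np
from scipy.linalg.blas import dgemm

rng = np.random.default_rng(0)
A = rng.standard_normal((300, 200))
B = rng.standard_normal((200, 100))

# One BLAS call computes the whole product: C = alpha * A @ B (alpha = 1.0).
C = dgemm(1.0, A, B)

print(C.shape)                # (300, 100)
print(np.allclose(C, A @ B))  # True
```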



