0

我有一段代码使用 Numbapro 编写一个简单的内核来对两个大小为 41724 的数组的内容进行平方,将它们加在一起并将其存储到另一个数组中。所有数组都具有相同的大小并且是 float32。代码如下:

import numpy as np
from numba import *
from numbapro import cuda

@cuda.jit('void(float32[:],float32[:],float32[:])')
def square_add(a,b,c):
    tx = cuda.threadIdx.x
    bx = cuda.blockIdx.x
    bw = cuda.blockDim.x

    i = tx + bx * bw

    #Since the length of a is 41724 and the total
    #threads is 41*1024 = 41984, this check is necessary
    if (i>len(a)):
            return
    else:
            c[i] = a[i]*a[i] + b[i]*b[i]


a = np.array(range(0,41724),dtype = np.float32)
b = np.array(range(41724,83448),dtype=np.float32)
c = np.zeros(shape=(1,41724),dtype=np.float32)

d_a = cuda.to_device(a)
d_b = cuda.to_device(b)
d_c = cuda.to_device(c,copy=False)

#Launch the kernel; Gridsize = (1,41),Blocksize=(1,1024)
square_add[(1,41),(1,1024)](d_a,d_b,d_c)

c = d_c.copy_to_host()
print c
print len(c[0])

打印操作结果(数组 c)时得到的值与在 python 终端中执行完全相同的操作时得到的值完全不同。我不知道我在这里做错了什么。

4

1 回答 1

1

这里有两个问题。

首先是您为 CUDA 内核启动指定了块和网格维度,这与您选择在内核中使用的索引方案不兼容。

这个:

square_add[(1,41),(1,1024)](d_a,d_b,d_c)

启动一个二维网格,其中所有线程在 x 中具有相同的块和线程尺寸,并且仅在 y 中不同。这意味着

tx = cuda.threadIdx.x
bx = cuda.blockIdx.x
bw = cuda.blockDim.x

i = tx + bx * bw

将为i=0每个线程产生。如果将内核启动更改为:

square_add[(41,1),(1024,1)](d_a,d_b,d_c)

你会发现 in indexing 会正常工作。

第二个是c被声明为二维数组,但内核函数签名被声明为一维数组。在某些情况下,numbapro 运行时应该检测到这一点并引发错误。

我能够让您的示例正常工作,如下所示:

import numpy as np
from numba import *
from numbapro import cuda

@cuda.jit('void(float32[:],float32[:],float32[:,:])')
def square_add(a,b,c):
    tx = cuda.threadIdx.x
    bx = cuda.blockIdx.x
    bw = cuda.blockDim.x

    i = tx + bx * bw

    if (i<len(a)):
        c[0,i] = a[i]*a[i] + b[i]*b[i]

a = np.array(range(0,41724),dtype=np.float32)
b = np.array(range(41724,83448),dtype=np.float32)
c = np.zeros(shape=(1,41724),dtype=np.float32)

d_a = cuda.to_device(a)
d_b = cuda.to_device(b)
d_c = cuda.to_device(c, copy=False)

square_add[(41,1),(1024,1)](d_a,d_b,d_c)

c = d_c.copy_to_host()
print(c)
print(c.shape)

[注意我使用的是 Python 3,所以这使用了新样式的打印语句]

$ ipython numbatest.py 
numbapro:1: ImportWarning: The numbapro package is deprecated in favour of the accelerate package. Please update your code to use equivalent functions from accelerate.
[[  1.74089216e+09   1.74097562e+09   1.74105907e+09 ...,   8.70371021e+09
    8.70396006e+09   8.70421094e+09]]
(1, 41724)
于 2016-04-04T08:19:59.730 回答