我对 CUDA 很陌生。
我正在尝试实例化许多粒子并对它们进行一些物理处理。
import numba as nb
from numba import cuda
from numba.experimental import jitclass
import numpy as np
spec = [('position', nb.float32[:,:,:]),('velocity', nb.float32[:,:,:]), ('mass', nb.float32)]
@jitclass(spec)
class Particle(object):
def __init__(self, pos, vel, mass = 1):
self.position = pos
self.velocity = vel
self.mass = mass
def move(self, dt):
self.position += self.velocity * dt
@cuda.jit
def create_particle(rng_states, out):
thread_id = cuda.grid(1)
pos_x = xoroshiro128p_uniform_float32(rng_states, thread_id)
pos_y = xoroshiro128p_uniform_float32(rng_states, thread_id)
pos_z = xoroshiro128p_uniform_float32(rng_states, thread_id)
vel_x = xoroshiro128p_uniform_float32(rng_states, thread_id)
vel_y = xoroshiro128p_uniform_float32(rng_states, thread_id)
vel_z = xoroshiro128p_uniform_float32(rng_states, thread_id)
out[thread_id] = Particle([pos_x, pos_y, pos_z], [vel_x, vel_y, vel_z])
def instantiate_particles(n):
threadsperblock = 512
blockspergrid = (n + (threadsperblock - 1)) // threadsperblock
seed = 12345
rng_states = create_xoroshiro128p_states(threadsperblock * blockspergrid, seed)
out = np.zeros(n, dtype=np.object)
create_particle[blockspergrid, threadsperblock](rng_states, out)
instantiate_particles(512)
上面的代码会导致这个错误,
TypingError:无法确定 <class 'numba.experimental.jitclass.base.JitClassType'> 的 Numba 类型
我认为这是由
out = np.zeros(n, dtype=np.object)
实例化许多对象的正确方法是什么?