0

我对 CUDA 很陌生。

我正在尝试实例化许多粒子并对它们进行一些物理处理。

import numba as nb
from numba import cuda
from numba.experimental import jitclass
import numpy as np

spec = [('position', nb.float32[:,:,:]),('velocity', nb.float32[:,:,:]), ('mass', nb.float32)]
@jitclass(spec)
class Particle(object):
    def __init__(self, pos, vel, mass = 1):
        self.position = pos
        self.velocity = vel
        self.mass = mass
        
    def move(self, dt):
        self.position += self.velocity * dt

@cuda.jit
def create_particle(rng_states, out):
    thread_id = cuda.grid(1)
    pos_x = xoroshiro128p_uniform_float32(rng_states, thread_id)
    pos_y = xoroshiro128p_uniform_float32(rng_states, thread_id)
    pos_z = xoroshiro128p_uniform_float32(rng_states, thread_id)
    vel_x = xoroshiro128p_uniform_float32(rng_states, thread_id)
    vel_y = xoroshiro128p_uniform_float32(rng_states, thread_id)
    vel_z = xoroshiro128p_uniform_float32(rng_states, thread_id)
    out[thread_id] = Particle([pos_x, pos_y, pos_z], [vel_x, vel_y, vel_z])
    
def instantiate_particles(n):
    threadsperblock = 512
    blockspergrid = (n + (threadsperblock - 1)) // threadsperblock
    seed = 12345
    rng_states = create_xoroshiro128p_states(threadsperblock * blockspergrid, seed)
    out = np.zeros(n, dtype=np.object)
    create_particle[blockspergrid, threadsperblock](rng_states, out)

instantiate_particles(512)

上面的代码会导致这个错误,

TypingError:无法确定 <class 'numba.experimental.jitclass.base.JitClassType'> 的 Numba 类型

我认为这是由

out = np.zeros(n, dtype=np.object)

实例化许多对象的正确方法是什么?

4

0 回答 0