0
import numpy as np
from numba import njit

dt = np.dtype([('x', np.float64), ('y', np.float64)])

@njit
def f():
#    a = np.zeros(2, dtype=dt)         # this works
#    return a['x']
    b = np.array((0.5, 1.5), dtype=dt)    # this doesn't
#    return b['x']
f()

错误信息是

NotImplementedError: Cannot cast float64 to Record(x[type=float64;offset=0],y[type=float64;offset=8];16;False): %".69" = phi double [%".70", %"switch.0"], [%".72", %"switch.1"]

没有@jit 它可以正常工作。

我真正想要实现的是创建一个自定义 dtype 标量列表。我尝试了以下替代方案:

  • numpy 数组不适合,因为事先不知道元素的数量,
  • 一个 numpy 向量列表需要太多的数字索引,这是我试图避免的
  • dicts 列表(据说)占用太多内存
  • 与上面列出的任何其他变体相比,一个 numpy 的 jited 类数组太慢了。

更新: 到目前为止我能得到的最远点是:

dt = np.dtype([('x', np.float64), ('y', np.float64)])
@nb.njit
def f():
    a = np.array((0.5, 1.5))
    b = a.view(dt)
    return b.x
f()

array([0.5])

但它不是一个标量,它是一个大小为 1 的数组(有或没有 @jit)。

更新2:

Recfunctions 还没有被 numba 覆盖。

from numpy.lib import recfunctions
from numba import njit
dt = np.dtype([('x', np.float64), ('y', np.float64)])
@njit
def f():
    a = np.array((1,2))
    b = recfunctions.unstructured_to_structured(a, dt)
    return b['x']
f()

Unknown attribute 'unstructured_to_structured' of type 
Module(<module 'numpy.lib.recfunctions'
4

1 回答 1

1

显然numba还没有完全实现numpy结构化数组功能。该错误表明将元组中的值分配给定义的数组时出现问题。

玩了一会儿后,我发现这行得通:

In [399]: dt = np.dtype([('x', np.float64),('y', np.float64)])                  
In [400]: @numba.njit 
     ...: def nf(vals, dt): 
     ...:     b = np.zeros((), dtype=dt) 
     ...:     b['x'][...] = vals[0] 
     ...:     b['y'][...] = vals[1] 
     ...:     return b 
     ...:                                                                       
In [401]: nf((.5,1.5),dt)                                                       
Out[401]: array((0.5, 1.5), dtype=[('x', '<f8'), ('y', '<f8')])

或制作一维数组:

In [405]: @numba.njit 
     ...: def nf1(n, x, y , dt): 
     ...:     b = np.zeros(n, dtype=dt) 
     ...:     b['x'][...] = x 
     ...:     b['y'][...] = y 
     ...:     return b 
     ...:                                                                       
In [406]: nf1(3, np.arange(3), np.ones(3), dt)                                  
Out[406]: array([(0., 1.), (1., 1.), (2., 1.)], dtype=[('x', '<f8'), ('y', '<f8')])
于 2019-11-10T19:22:28.060 回答