3

我未能将布尔掩码保存为 Cython 类的属性。在实际代码中,我需要这个掩码来更有效地执行任务。下面是一个示例代码:

核心.pyx

import numpy as np
cimport numpy as np

cdef class MyClass():
    cdef public np.uint8_t[:] mask # uint8 has the same data structure of a boolean array
    cdef public np.float64_t[:] data

    def __init__(self, size):
        self.data = np.random.rand(size).astype(np.float64)
        self.mask = np.zeros(size, np.uint8)

脚本.py

import numpy as np
import pyximport
pyximport.install(setup_args={'include_dirs': np.get_include()})

from core import MyClass

mc = MyClass(1000000)
mc.mask = np.asarray(mc.data) > 0.5 

错误

当我运行script.py它成功编译 Cython,但抛出错误:

Traceback (most recent call last):
  File "script.py", line 8, in <module>
    mc.mask = np.asarray(mc.data) > 0.5
  File "core.pyx", line 6, in core.MyClass.mask.__set__
    cdef public np.uint8_t[:] mask
ValueError: Does not understand character buffer dtype format string ('?')

解决方法

我目前的解决方法是将掩码传递给我需要的所有函数,cast=True例如:

cpdef func(MyClass mc, np.ndarray[np.uint8_t, ndim=1, cast=True] mask):
    return np.asarray(mc.data)[mask]

问题

关于如何将面具存储在 Cython 类中是否有任何想法?

4

2 回答 2

2

所以我不相信内存视图实际上支持布尔索引。因此,要索引数组,您总是需要做

np.asarray(mc.data)[mask]
# or
mc.data.base[mask] # if you're sure it's always a view of something that supports boolean indexing)

我认为这不会随着@ead 提到的 Cython 更新而改变。我怀疑这样做的原因是赋值 ( mc.data[mask] = x) 可能相当容易,但是应该返回什么类型并不明显mc.data[mask]——它不是内存视图。

因此,无论您做什么,都会涉及一些混乱的代码。


对于分配给内存视图的部分,可以使用

mc.mask = (np.asarray(mc.data) > 0.5).view(np.uint8)

并将其返回到一个 Numpy bool 数组:

np.asarray(mc.mask).view(np.bool)

两者都不应该涉及复制。


如果是我设计这个,我会保持 memoryviews 非公开(仅供 Cython 使用)并具有仅保存 Python 接口的底层 Numpy 数组的普通对象属性。您可以使用property使它们保持同步(并进行强制转换):

cdef class MyClass:
    cdef np.uint8_t[:] mask_mview
    cdef object _mask

    @property
    def mask(self):
        return np.asarray(self._mask).view(np.bool)

    @mask.setter
    def mask(self, value):
        self._mask = value
        self.mask_view = value.view(np.uint8)

    # and the same for data

这样,您就可以将 memoryview 用于 memoryviews 擅长的事情(在 Cython 中逐个元素快速迭代),访问 Python 的普通 Numpy 数组,并且两者保持同步(至少通过 Python 接口)。

于 2019-10-06T13:02:23.553 回答
1

您最好的选择(如果您不想使用解决方法)可能是等待 Cython 0.29.14 发布。此问题已修复,可能会成为0.29.14的一部分。

以下最小示例

%%cython
import numpy as np
cimport numpy as np
cdef np.uint8_t[:] mask  = np.random.rand(20)>.5

将无法正常导入

ValueError:不理解字符缓冲区 dtype 格式字符串('?')

对于 Cython 0.29.13,但使用github(或 master)上 0.29.x-branch的当前状态。

于 2019-10-06T10:22:19.563 回答