1

我正在尝试使用 cython 来提高循环的性能,但是在声明输入类型时遇到了一些问题。

如何在我的类型化结构中包含一个字段,该字段可以是“前”或“后”的字符串

我有一个np.recarray如下所示的(注意,recarray 的长度在编译时是未知的)

import numpy as np
weights = np.recarray(4, dtype=[('a', np.int64),  ('b', np.str_, 5), ('c', np.float64)])
weights[0] = (0, "front", 0.5)
weights[1] = (0, "back", 0.5)
weights[2] = (1, "front", 1.0)
weights[3] = (1, "back", 0.0)

以及字符串列表的输入和pandas.Timestamp

import pandas as pd
ts = pd.Timestamp("2015-01-01")
contracts = ["CLX16", "CLZ16"]

我正在尝试对以下循环进行cythonize

def ploop(weights, contracts, timestamp):
    cwts = []
    for gen_num, position, weighting in weights:
        if weighting != 0:
            if position == "front":
                cntrct_idx = gen_num
            elif position == "back":
                cntrct_idx = gen_num + 1
            else:
                raise ValueError("transition.columns must contain "
                                 "'front' or 'back'")
            cwts.append((gen_num, contracts[cntrct_idx], weighting, timestamp))
    return cwts

我的尝试涉及weights在 cython 中将输入作为结构输入,在文件struct_test.pyx中如下

import numpy as np
cimport numpy as np


cdef packed struct tstruct:
    np.int64_t gen_num
    char[5] position
    np.float64_t weighting


def cloop(tstruct[:] weights_array, contracts, timestamp):
    cdef tstruct weights
    cdef int i
    cdef int cntrct_idx

    cwts = []
    for k in xrange(len(weights_array)):
        w = weights_array[k]
        if w.weighting != 0:
            if w.position == "front":
                cntrct_idx = w.gen_num
            elif w.position == "back":
                cntrct_idx = w.gen_num + 1
            else:
                raise ValueError("transition.columns must contain "
                                 "'front' or 'back'")
            cwts.append((w.gen_num, contracts[cntrct_idx], w.weighting,
                         timestamp))
    return cwts

但我收到运行时错误,我认为这与 char[5] position.

import pyximport
pyximport.install()
import struct_test

struct_test.cloop(weights, contracts, ts)

ValueError: Does not understand character buffer dtype format string ('w')

此外,我有点不清楚我将如何进行打字contracts以及timestamp.

4

1 回答 1

1

您的ploop(不带timestamp变量)产生:

In [226]: ploop(weights, contracts)
Out[226]: [(0, 'CLX16', 0.5), (0, 'CLZ16', 0.5), (1, 'CLZ16', 1.0)]

没有循环的等效函数:

def ploopless(weights, contracts):
    arr_contracts = np.array(contracts) # to allow array indexing
    wgts1 = weights[weights['c']!=0]
    mask = wgts1['b']=='front'
    wgts1['b'][mask] = arr_contracts[wgts1['a'][mask]]
    mask = wgts1['b']=='back'
    wgts1['b'][mask] = arr_contracts[wgts1['a'][mask]+1]
    return wgts1.tolist()

In [250]: ploopless(weights, contracts)
Out[250]: [(0, 'CLX16', 0.5), (0, 'CLZ16', 0.5), (1, 'CLZ16', 1.0)]

我正在利用返回的元组列表具有与输入weight数组相同的 (int, str, int) 布局这一事实。所以我只是复制并替换该字段weights的选定值。b

请注意,我使用的是字段选择索引mask。布尔值mask产生一个副本,所以我们必须小心索引顺序。

我猜想无循环数组版本将在时间上与cloop(在现实数组上)竞争。中的字符串和列表操作cloop可能会限制其加速。

于 2017-07-05T21:34:44.903 回答