1

编辑:这里的 GitHub 问题:https ://github.com/google/jax/issues/5190

我正在尝试使用 jit 优化以下功能:

@partial(jit, static_argnums=(0, 1,))
def coocurrence_helper(pairs: np.array, label_map: Dict) -> lil_matrix:
    uniques = lil_matrix(np.zeros((len(label_map), len(label_map))).astype("int32"))
    for item in pairs:
        if item[0]!=item[1]:
            uniques[label_map[item[0]], label_map[item[1]]] += 1
    return uniques

上面的例程在这里使用:

def _get_pairwise_frequencies(
     data: pd.DataFrame, crosstab=False
    ) -> pd.DataFrame:
        values = data.stack()
        values.index = values.index.droplevel(1)
        values.name = "vals"
        values = optimize(values.to_frame())
        pair = optimize(values.join(values, rsuffix="_2"))
        label_map = dict()
        for lbl, each in enumerate(values.vals.unique()):
            label_map[each] = lbl
        if not crosstab:
            freq = coocurrence_helper(pairs = pair.values, label_map=label_map)
            return ((freq / freq.sum(1).ravel()).astype(np.float32))
        else:
            freq = pd.crosstab(pair["vals"], pair["vals_2"])
            self.index = freq.index
            return csr_matrix((freq / freq.sum(1)).astype(np.float32))

但我收到以下错误:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-42-f8e638fc2bb6> in <module>
----> 1 _get_pairwise_frequencies(data)

<ipython-input-30-43adeb39c76c> in _get_pairwise_frequencies(data, crosstab)
     25             label_map[each] = lbl
     26         if not crosstab:
---> 27             freq = coocurrence_helper(pairs = pair.values, label_map=label_map)
     28             return csr_matrix((freq / freq.sum(1).ravel()).astype(np.float32))
     29         else:

~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    369         return cache_miss(*args, **kwargs)[0]  # probably won't return
    370     else:
--> 371       return cpp_jitted_f(*args, **kwargs)
    372   f_jitted._cpp_jitted_f = cpp_jitted_f
    373 

ValueError: vector::reserve

这里问题的根源是什么?不使用static_argnums错误消息是

RuntimeError: Invalid argument: Unknown NumPy type O size 8

具有相同的回溯。

4

1 回答 1

1

问题是您返回的scipy.sparse.lil_matrix不是有效的 JAX 类型。JAXjit装饰器不能用作任意 Python 代码的编译器;它旨在优化 JAX 数组上的操作序列。

在这种情况下,最好的方法可能是@partial(jit, ...)从你的函数中删除装饰器;如果你想在这里使用 JAX jit 编译,你首先必须重写你的代码以避免scipy.sparse矩阵并使用 JAX 数组。

于 2020-12-15T14:40:30.823 回答