编辑:对应的 GitHub issue 在这里:https://github.com/google/jax/issues/5190
我正在尝试使用 jit 优化以下函数:
# Count how often each ordered pair of labels occurs and return the counts
# as a square sparse matrix.
#
# Parameters:
#   pairs:     iterable of 2-element items; item[0] and item[1] are each
#              looked up in label_map.
#   label_map: maps a label value to its row/column index in the result.
#
# Returns:
#   An int32 lil_matrix with len(label_map) rows and columns, where entry
#   [i, j] is the number of pairs whose elements map to (i, j). Pairs whose
#   two elements are equal are not counted.
#
# NOTE(review): presumably `jit` is jax.jit and `partial` is
# functools.partial (imports are not shown). Both arguments are marked
# static; JAX documents that static arguments must be hashable, which a
# NumPy array and a dict are not -- TODO confirm against the jax.jit docs.
# NOTE(review): the body is a Python loop doing dict lookups and in-place
# updates of a SciPy sparse matrix; it looks unlikely that jit can trace
# this -- confirm before relying on the decorator.
@partial(jit, static_argnums=(0, 1,))
def coocurrence_helper(pairs: np.array, label_map: Dict) -> lil_matrix:
    # Built from a dense all-zero square array with one row/column per label.
    uniques = lil_matrix(np.zeros((len(label_map), len(label_map))).astype("int32"))
    for item in pairs:
        # Skip pairs whose two elements are equal.
        if item[0]!=item[1]:
            # Translate both labels to matrix indices and bump that cell.
            uniques[label_map[item[0]], label_map[item[1]]] += 1
    return uniques
上面的函数在下面这段代码中被调用:
# Build a matrix of pairwise frequencies from the values in `data`.
#
# Parameters:
#   data:     DataFrame whose values are stacked into a single column and
#             then joined with themselves on the index to form pairs.
#   crosstab: when False (the default) the pair counts come from
#             coocurrence_helper; when True they come from pd.crosstab.
#
# Returns:
#   The counts divided by their sums along axis 1, cast to float32. The
#   crosstab branch wraps the result in csr_matrix.
#
# NOTE(review): the return annotation says pd.DataFrame, but neither branch
# visibly returns one -- confirm the intended return type.
def _get_pairwise_frequencies(
    data: pd.DataFrame, crosstab=False
) -> pd.DataFrame:
    # Flatten all columns into one Series, then drop index level 1
    # (presumably the column-label level added by stack()) so each value is
    # keyed only by the remaining index level.
    values = data.stack()
    values.index = values.index.droplevel(1)
    values.name = "vals"
    # `optimize` is defined outside this excerpt; its effect is not visible here.
    values = optimize(values.to_frame())
    # Self-join on the index: values sharing an index label are paired up,
    # giving the columns "vals" and "vals_2".
    pair = optimize(values.join(values, rsuffix="_2"))
    # Give every distinct value a consecutive integer id, in the order
    # unique() yields them.
    label_map = dict()
    for lbl, each in enumerate(values.vals.unique()):
        label_map[each] = lbl
    if not crosstab:
        # NOTE(review): both arguments are passed by keyword, while the
        # helper's decorator selects its static arguments by position --
        # verify that the static marking applies to keyword arguments.
        freq = coocurrence_helper(pairs = pair.values, label_map=label_map)
        # Divide the counts by the flattened axis-1 sums.
        return ((freq / freq.sum(1).ravel()).astype(np.float32))
    else:
        freq = pd.crosstab(pair["vals"], pair["vals_2"])
        # NOTE(review): `self` is not a parameter of this function as shown;
        # unless this is really a method whose `self` was dropped from the
        # excerpt, this line cannot resolve the name -- confirm.
        self.index = freq.index
        return csr_matrix((freq / freq.sum(1)).astype(np.float32))
但我收到以下错误:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-42-f8e638fc2bb6> in <module>
----> 1 _get_pairwise_frequencies(data)
<ipython-input-30-43adeb39c76c> in _get_pairwise_frequencies(data, crosstab)
25 label_map[each] = lbl
26 if not crosstab:
---> 27 freq = coocurrence_helper(pairs = pair.values, label_map=label_map)
28 return csr_matrix((freq / freq.sum(1).ravel()).astype(np.float32))
29 else:
~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
369 return cache_miss(*args, **kwargs)[0] # probably won't return
370 else:
--> 371 return cpp_jitted_f(*args, **kwargs)
372 f_jitted._cpp_jitted_f = cpp_jitted_f
373
ValueError: vector::reserve
这里问题的根源是什么?如果不使用 static_argnums,
错误消息则变为:
RuntimeError: Invalid argument: Unknown NumPy type O size 8
回溯信息与上面相同。