以下代码返回一个数组而不是预期的浮点值。
def f(x):
return x+1
f = np.vectorize(f, otypes=[np.float])
>>> f(10.5)
array(11.5)
如果输入是标量而不是奇怪的数组类型,有没有办法强制它返回简单的标量值?
我觉得很奇怪它默认不这样做,因为所有其他 ufunc (如 np.cos、np.sin 等)确实返回常规标量
编辑:这是有效的代码:
import numpy as np
import functools
def as_scalar_if_possible(func):
@functools.wraps(func) #this is here just to preserve signature
def wrapper(*args, **kwargs):
return func(*args, **kwargs)[()]
return wrapper
@as_scalar_if_possible
@np.vectorize
def f(x):
return x + 1
print(f(11.5)) # 打印 12.5