只要您的函数专门使用 ufunc(或类似的东西)来隐式循环输入值,Pierre GM 的答案就很好。如果您的函数需要迭代输入,那么np.asarray
就做得不够,因为您不能迭代 NumPy 标量:
import numpy as np
x = np.asarray(1)
for xval in x:
print(np.exp(xval))
Traceback (most recent call last):
File "Untitled 2.py", line 4, in <module>
for xval in x:
TypeError: iteration over a 0-d array
如果您的函数需要遍历输入,则可以使用np.atleast_1d
和np.squeeze
(请参阅数组操作例程 — NumPy 手册)。我包含了一个aaout
("Always Array OUT") 参数,因此您可以指定是否希望标量输入产生单元素数组输出;如果不需要,它可以被丢弃:
def scalar_or_iter_in(x, aaout=False):
"""
Gather function evaluations over scalar or iterable `x` values.
aaout :: boolean
"Always array output" flag: If True, scalar input produces
a 1-D, single-element array output. If false, scalar input
produces scalar output.
"""
x = np.asarray(x)
scalar_in = x.ndim==0
# Could use np.array instead of np.atleast_1d, as follows:
# xvals = np.array(x, copy=False, ndmin=1)
xvals = np.atleast_1d(x)
y = np.empty_like(xvals, dtype=float) # dtype in case input is ints
for i, xx in enumerate(xvals):
y[i] = np.exp(xx) # YOUR OPERATIONS HERE!
if scalar_in and not aaout:
return np.squeeze(y)
else:
return y
print(scalar_or_iter_in(1.))
print(scalar_or_iter_in(1., aaout=True))
print(scalar_or_iter_in([1,2,3]))
2.718281828459045
[2.71828183]
[ 2.71828183 7.3890561 20.08553692]
当然,对于求幂,您不应该像这里那样显式迭代,但是使用 NumPy ufunc 可能无法表达更复杂的操作。如果您不需要迭代,但希望对标量输入是否产生单元素数组输出进行类似控制,则函数的中间部分可能更简单,但返回必须处理np.atleast_1d
:
def scalar_or_iter_in(x, aaout=False):
"""
Gather function evaluations over scalar or iterable `x` values.
aaout :: boolean
"Always array output" flag: If True, scalar input produces
a 1-D, single-element array output. If false, scalar input
produces scalar output.
"""
x = np.asarray(x)
scalar_in = x.ndim==0
y = np.exp(x) # YOUR OPERATIONS HERE!
if scalar_in and not aaout:
return np.squeeze(y)
else:
return np.atleast_1d(y)
我怀疑在大多数情况下,该aaout
标志不是必需的,并且您总是希望标量输出和标量输入。在这种情况下,回报应该是:
if scalar_in:
return np.squeeze(y)
else:
return y