我遇到了一个 scipy 函数,无论传递给它什么,它似乎都会返回一个 numpy 数组。在我的应用程序中,我只需要能够传递标量和列表,因此唯一的“问题”是当我将标量传递给函数时,会返回一个包含一个元素的数组(当我期望一个标量时)。我应该忽略这种行为,还是修改函数以确保在传递标量时返回标量?
示例代码:
#! /usr/bin/env python
import scipy
import scipy.optimize
from numpy import cos
# This a some function we want to compute the inverse of
def f(x):
y = x + 2*cos(x)
return y
# Given y, this returns x such that f(x)=y
def f_inverse(y):
# This will be zero if f(x)=y
def minimize_this(x):
return y-f(x)
# A guess for the solution is required
x_guess = y
x_optimized = scipy.optimize.fsolve(minimize_this, x_guess) # THE PROBLEM COMES FROM HERE
return x_optimized
# If I call f_inverse with a list, a numpy array is returned
print f_inverse([1.0, 2.0, 3.0])
print type( f_inverse([1.0, 2.0, 3.0]) )
# If I call f_inverse with a tuple, a numpy array is returned
print f_inverse((1.0, 2.0, 3.0))
print type( f_inverse((1.0, 2.0, 3.0)) )
# If I call f_inverse with a scalar, a numpy array is returned
print f_inverse(1.0)
print type( f_inverse(1.0) )
# This is the behaviour I expected (scalar passed, scalar returned).
# Adding [0] on the return value is a hackey solution (then thing would break if a list were actually passed).
print f_inverse(1.0)[0] # <- bad solution
print type( f_inverse(1.0)[0] )
在我的系统上,这个输出是:
[ 2.23872989 1.10914418 4.1187546 ]
<type 'numpy.ndarray'>
[ 2.23872989 1.10914418 4.1187546 ]
<type 'numpy.ndarray'>
[ 2.23872989]
<type 'numpy.ndarray'>
2.23872989209
<type 'numpy.float64'>
我正在使用 MacPorts 提供的 SciPy 0.10.1 和 Python 2.7.3。
解决方案
阅读下面的答案后,我选择了以下解决方案。将返回行替换为f_inverse
:
if(type(y).__module__ == np.__name__):
return x_optimized
else:
return type(y)(x_optimized)
这里return type(y)(x_optimized)
导致返回类型与调用函数的类型相同。不幸的是,如果 y 是 numpy 类型,则这不起作用,因此if(type(y).__module__ == np.__name__)
用于使用此处提出的想法检测 numpy 类型并将它们从类型转换中排除。