我遇到了一个scipy函数,无论传递给它什么,它似乎都返回一个numpy数组。在我的应用程序中,我只需要能够传递标量和列表,因此唯一的“问题”是当我将标量传递给函数时,将返回带有一个元素的数组(当我期望标量时)。我应该忽略这种行为,还是修改函数以确保在传递标量时返回标量?
示例代码:
#! /usr/bin/env python import scipy import scipy.optimize from numpy import cos # This a some function we want to compute the inverse of def f(x): y = x + 2*cos(x) return y # Given y, this returns x such that f(x)=y def f_inverse(y): # This will be zero if f(x)=y def minimize_this(x): return y-f(x) # A guess for the solution is required x_guess = y x_optimized = scipy.optimize.fsolve(minimize_this, x_guess) # THE PROBLEM COMES FROM HERE return x_optimized # If I call f_inverse with a list, a numpy array is returned print f_inverse([1.0, 2.0, 3.0]) print type( f_inverse([1.0, 2.0, 3.0]) ) # If I call f_inverse with a tuple, a numpy array is returned print f_inverse((1.0, 2.0, 3.0)) print type( f_inverse((1.0, 2.0, 3.0)) ) # If I call f_inverse with a scalar, a numpy array is returned print f_inverse(1.0) print type( f_inverse(1.0) ) # This is the behaviour I expected (scalar passed, scalar returned). # Adding [0] on the return value is a hackey solution (then thing would break if a list were actually passed). print f_inverse(1.0)[0] # <- bad solution print type( f_inverse(1.0)[0] )
在我的系统上,此输出为:
[ 2.23872989 1.10914418 4.1187546 ] <type 'numpy.ndarray'> [ 2.23872989 1.10914418 4.1187546 ] <type 'numpy.ndarray'> [ 2.23872989] <type 'numpy.ndarray'> 2.23872989209 <type 'numpy.float64'>
我正在使用MacPorts提供的SciPy 0.10.1和Python 2.7.3。
解
阅读以下答案后,我决定采用以下解决方案。将回车线替换为f_inverse:
if(type(y).__module__ == np.__name__): return x_optimized else: return type(y)(x_optimized)
这将return type(y)(x_optimized)导致返回类型与调用该函数的类型相同。不幸的是,如果y是一个numpy类型,则此方法不起作用,因此if(type(y).__module__ == np.__name__)使用此处介绍的思想来检测numpy类型并将其排除在类型转换之外。
将return type(y)(x_optimized)
if(type(y).__module__ == np.__name__)
实现的第一行scipy.optimize.fsolve是:
scipy.optimize.fsolve
x0 = array(x0, ndmin=1)
这意味着您的标量将变成1个元素的序列,而1个元素的序列将基本不变。
看起来有效的事实是实现细节,我将重构您的代码以不允许向中发送标量fsolve。我知道这似乎违背了鸭式输入法,但是该函数要求ndarray提供该参数,因此您应该尊重该接口,使其对实现中的更改具有鲁棒性。但是,我看不到有条件地x_guess = array(y, ndmin=1)用于将标量转换为ndarray包装函数中的值,并在必要时将结果转换回标量的任何问题。
fsolve
ndarray
x_guess = array(y, ndmin=1)
这是fsolve函数docstring的相关部分:
def fsolve(func, x0, args=(), fprime=None, full_output=0, col_deriv=0, xtol=1.49012e-8, maxfev=0, band=None, epsfcn=0.0, factor=100, diag=None): """ Find the roots of a function. Return the roots of the (non-linear) equations defined by ``func(x) = 0`` given a starting estimate. Parameters ---------- func : callable f(x, *args) A function that takes at least one (possibly vector) argument. x0 : ndarray The starting estimate for the roots of ``func(x) = 0``. ----SNIP---- Returns ------- x : ndarray The solution (or the result of the last iteration for an unsuccessful call). ----SNIP----