小编典典

scipy函数总是返回一个numpy数组

python

我遇到了一个scipy函数,无论传递给它什么,它似乎都返回一个numpy数组。在我的应用程序中,我只需要能够传递标量和列表,因此唯一的“问题”是当我将标量传递给函数时,将返回带有一个元素的数组(当我期望标量时)。我应该忽略这种行为,还是修改函数以确保在传递标量时返回标量?

示例代码:

#! /usr/bin/env python

import scipy
import scipy.optimize
from numpy import cos

# This a some function we want to compute the inverse of
def f(x):
    y = x + 2*cos(x)
    return y

# Given y, this returns x such that f(x)=y
def f_inverse(y):

    # This will be zero if f(x)=y
    def minimize_this(x):
        return y-f(x)

    # A guess for the solution is required
    x_guess = y
    x_optimized = scipy.optimize.fsolve(minimize_this, x_guess) # THE PROBLEM COMES FROM HERE
    return x_optimized

# If I call f_inverse with a list, a numpy array is returned
print f_inverse([1.0, 2.0, 3.0])
print type( f_inverse([1.0, 2.0, 3.0]) )

# If I call f_inverse with a tuple, a numpy array is returned
print f_inverse((1.0, 2.0, 3.0))
print type( f_inverse((1.0, 2.0, 3.0)) )

# If I call f_inverse with a scalar, a numpy array is returned
print f_inverse(1.0)
print type( f_inverse(1.0) )

# This is the behaviour I expected (scalar passed, scalar returned).
# Adding [0] on the return value is a hackey solution (then thing would break if a list were actually passed).
print f_inverse(1.0)[0] # <- bad solution
print type( f_inverse(1.0)[0] )

在我的系统上,此输出为:

[ 2.23872989  1.10914418  4.1187546 ]
<type 'numpy.ndarray'>
[ 2.23872989  1.10914418  4.1187546 ]
<type 'numpy.ndarray'>
[ 2.23872989]
<type 'numpy.ndarray'>
2.23872989209
<type 'numpy.float64'>

我正在使用MacPorts提供的SciPy 0.10.1和Python 2.7.3。

阅读以下答案后,我决定采用以下解决方案。将回车线替换为f_inverse:

if(type(y).__module__ == np.__name__):
    return x_optimized
else:
    return type(y)(x_optimized)

将return type(y)(x_optimized)导致返回类型与调用该函数的类型相同。不幸的是,如果y是一个numpy类型,则此方法不起作用,因此if(type(y).__module__ == np.__name__)使用此处介绍的思想来检测numpy类型并将其排除在类型转换之外。


阅读 231

收藏
2021-01-20

共1个答案

小编典典

实现的第一行scipy.optimize.fsolve是:

x0 = array(x0, ndmin=1)

这意味着您的标量将变成1个元素的序列,而1个元素的序列将基本不变。

看起来有效的事实是实现细节,我将重构您的代码以不允许向中发送标量fsolve。我知道这似乎违背了鸭式输入法,但是该函数要求ndarray提供该参数,因此您应该尊重该接口,使其对实现中的更改具有鲁棒性。但是,我看不到有条件地x_guess = array(y, ndmin=1)用于将标量转换为ndarray包装函数中的值,并在必要时将结果转换回标量的任何问题。

这是fsolve函数docstring的相关部分:

def fsolve(func, x0, args=(), fprime=None, full_output=0,
           col_deriv=0, xtol=1.49012e-8, maxfev=0, band=None,
           epsfcn=0.0, factor=100, diag=None):
    """
    Find the roots of a function.

    Return the roots of the (non-linear) equations defined by
    ``func(x) = 0`` given a starting estimate.

    Parameters
    ----------
    func : callable f(x, *args)
        A function that takes at least one (possibly vector) argument.
    x0 : ndarray
        The starting estimate for the roots of ``func(x) = 0``.

    ----SNIP----

    Returns
    -------
    x : ndarray
        The solution (or the result of the last iteration for
        an unsuccessful call).

    ----SNIP----
2021-01-20