Python scipy.linalg 模块,get_blas_funcs() 实例源码

我们从 Python 开源项目中提取了以下 8 个代码示例,用于说明如何使用 scipy.linalg.get_blas_funcs()。

项目:Parallel-SGD    作者:angadgill    | 项目源码 | 文件源码
def _fast_dot(A, B):
    """Multiply ``A`` by ``B`` through the BLAS ``gemm`` routine.

    A bare ``ValueError`` is raised whenever the operands are not eligible
    for the BLAS path:

    * the last axis of ``A`` does not match the first axis of ``B``;
    * the dtypes differ or are not float32/float64 (a ``NonBLASDotWarning``
      is emitted before raising in this case);
    * either operand is not 2-D, or its smallest dimension equals 1.

    Otherwise both operands are passed through ``_impose_f_order`` and the
    result of the ``gemm`` call is returned.
    """
    # Inner-dimension check adopted from '_dotblas.c'.
    if B.shape[0] != A.shape[A.ndim - 1]:
        raise ValueError

    blas_float_types = (np.float32, np.float64)
    dtypes_match = A.dtype == B.dtype
    all_blas_floats = all(arr.dtype in blas_float_types for arr in (A, B))
    if not (dtypes_match and all_blas_floats):
        message = ('Falling back to np.dot. '
                   'Data must be of same type of either '
                   '32 or 64 bit float for the BLAS function, gemm, to be '
                   'used for an efficient dot operation. ')
        warnings.warn(message, NonBLASDotWarning)
        raise ValueError

    has_unit_dimension = min(A.shape) == 1 or min(B.shape) == 1
    both_two_dimensional = A.ndim == 2 and B.ndim == 2
    if has_unit_dimension or not both_two_dimensional:
        raise ValueError

    # scipy 0.9 compliant API; the routine is selected from the operands
    # as they were passed in, before any reordering below.
    gemm, = linalg.get_blas_funcs(['gemm'], (A, B))
    A, trans_a = _impose_f_order(A)
    B, trans_b = _impose_f_order(B)
    return gemm(alpha=1.0, a=A, b=B, trans_a=trans_a, trans_b=trans_b)
项目:Lyssandra    作者:ektormak    | 项目源码 | 文件源码
def norm(x):
    """Return the result of the BLAS ``nrm2`` routine applied to ``x``.

    The routine variant is selected by ``get_blas_funcs`` from ``x`` itself.
    """
    (blas_nrm2,) = get_blas_funcs(('nrm2',), (x,))
    return blas_nrm2(x)
项目:gcMapExplorer    作者:rjdkmr    | 项目源码 | 文件源码
def mdot(v, w):
    """Return the product of ``v`` and ``w`` computed by BLAS ``gemm``.

    The ``gemm`` variant is chosen by ``get_blas_funcs`` from the two
    operands, and it is called with a scaling factor ``alpha`` of 1.
    """
    blas_gemm = get_blas_funcs("gemm", (v, w))
    product = blas_gemm(alpha=1., a=v, b=w)
    return product
项目:pactools    作者:pactools    | 项目源码 | 文件源码
def norm(x):
    """Compute the Euclidean or Frobenius norm of x.

    Gives the Euclidean norm for a vector and the Frobenius norm for a
    matrix (2-d array). More precise than sqrt(squared_norm(x)).

    Input with any element that has a non-zero imaginary part is handled
    through ``sqrt(squared_norm(x))``; everything else goes to the BLAS
    ``nrm2`` routine.
    """
    x = np.asarray(x)

    has_imaginary_part = np.any(np.iscomplex(x))
    if not has_imaginary_part:
        (blas_nrm2,) = linalg.get_blas_funcs(['nrm2'], [x])
        return blas_nrm2(x)

    return np.sqrt(squared_norm(x))
项目:Parallel-SGD    作者:angadgill    | 项目源码 | 文件源码
def norm(x):
    """Compute the Euclidean or Frobenius norm of x.

    Gives the Euclidean norm for a vector and the Frobenius norm for a
    matrix (2-d array). More precise than sqrt(squared_norm(x)).

    The input is converted with ``np.asarray`` and handed to the BLAS
    ``nrm2`` routine selected for that array.
    """
    arr = np.asarray(x)
    blas_routines = linalg.get_blas_funcs(['nrm2'], [arr])
    return blas_routines[0](arr)


# Newer NumPy has a ravel that needs less copying.
项目:Parallel-SGD    作者:angadgill    | 项目源码 | 文件源码
def _have_blas_gemm():
    """Report whether a BLAS ``gemm`` routine can be obtained from scipy.

    Returns True when ``linalg.get_blas_funcs(['gemm'])`` succeeds. If that
    call raises ``AttributeError`` or ``ValueError``, a warning is issued
    and False is returned instead.
    """
    try:
        linalg.get_blas_funcs(['gemm'])
    except (AttributeError, ValueError):
        warnings.warn('Could not import BLAS, falling back to np.dot')
        gemm_available = False
    else:
        gemm_available = True
    return gemm_available


# Only use fast_dot for older NumPy; newer ones have tackled the speed issue.
项目:Parallel-SGD    作者:angadgill    | 项目源码 | 文件源码
def test_inplace_swap_row():
    """Check ``inplace_swap_row`` on CSR and CSC input against a dense swap.

    For float64 and then float32 data, rows of a dense reference array are
    exchanged with the BLAS ``swap`` routine, the same rows are exchanged in
    the sparse copies, and all three representations must stay equal.
    """
    # (first row, second row) pairs exchanged in sequence for every dtype.
    row_pairs = ((0, -1), (2, 3))

    for dtype in (np.float64, np.float32):
        X = np.array([[0, 3, 0],
                      [2, 4, 0],
                      [0, 0, 0],
                      [9, 8, 7],
                      [4, 0, 5]], dtype=dtype)
        X_csr = sp.csr_matrix(X)
        X_csc = sp.csc_matrix(X)

        swap, = linalg.get_blas_funcs(('swap',), (X,))
        for first, second in row_pairs:
            # Dense reference: exchange the two rows via BLAS swap.
            X[first], X[second] = swap(X[first], X[second])
            inplace_swap_row(X_csr, first, second)
            inplace_swap_row(X_csc, first, second)
            assert_array_equal(X_csr.toarray(), X_csc.toarray())
            assert_array_equal(X, X_csc.toarray())
            assert_array_equal(X, X_csr.toarray())

        # NOTE(review): only the matrix is passed here, without the two row
        # indices used in the calls above -- confirm the TypeError comes from
        # the LIL input being rejected rather than from the missing arguments.
        assert_raises(TypeError, inplace_swap_row, X_csr.tolil())
项目:Parallel-SGD    作者:angadgill    | 项目源码 | 文件源码
def test_inplace_swap_column():
    """Check ``inplace_swap_column`` on CSR and CSC input against a dense swap.

    For float64 and then float32 data, columns of a dense reference array
    are exchanged with the BLAS ``swap`` routine, the same columns are
    exchanged in the sparse copies, and all three representations must stay
    equal.
    """
    # (first column, second column) pairs exchanged in sequence per dtype.
    column_pairs = ((0, -1), (0, 1))

    for dtype in (np.float64, np.float32):
        X = np.array([[0, 3, 0],
                      [2, 4, 0],
                      [0, 0, 0],
                      [9, 8, 7],
                      [4, 0, 5]], dtype=dtype)
        X_csr = sp.csr_matrix(X)
        X_csc = sp.csc_matrix(X)

        swap, = linalg.get_blas_funcs(('swap',), (X,))
        for first, second in column_pairs:
            # Dense reference: exchange the two columns via BLAS swap.
            X[:, first], X[:, second] = swap(X[:, first], X[:, second])
            inplace_swap_column(X_csr, first, second)
            inplace_swap_column(X_csc, first, second)
            assert_array_equal(X_csr.toarray(), X_csc.toarray())
            assert_array_equal(X, X_csc.toarray())
            assert_array_equal(X, X_csr.toarray())

        # NOTE(review): only the matrix is passed here, without the two column
        # indices used in the calls above -- confirm the TypeError comes from
        # the LIL input being rejected rather than from the missing arguments.
        assert_raises(TypeError, inplace_swap_column, X_csr.tolil())