小编典典

强制NumPy ndarray取得其在Cython中的内存所有权

python

遵循以下答案:“我是否可以强制一个numpy ndarray占用其内存?”我试图PyArray_ENABLEFLAGS通过Cython的NumPy包装器使用Python
C API函数,但发现它没有公开。

以下尝试手动将其公开(这只是重现故障的最小示例)

from libc.stdlib cimport malloc
import numpy as np
cimport numpy as np

np.import_array()

ctypedef np.int32_t DTYPE_t

cdef extern from "numpy/ndarraytypes.h":
    void PyArray_ENABLEFLAGS(np.PyArrayObject *arr, int flags)

def test():
    cdef int N = 1000

    cdef DTYPE_t *data = <DTYPE_t *>malloc(N * sizeof(DTYPE_t))
    cdef np.ndarray[DTYPE_t, ndim=1] arr = np.PyArray_SimpleNewFromData(1, &N, np.NPY_INT32, data)
    PyArray_ENABLEFLAGS(arr, np.NPY_ARRAY_OWNDATA)

失败并出现编译错误:

Error compiling Cython file:
------------------------------------------------------------
...
def test():
    cdef int N = 1000

    cdef DTYPE_t *data = <DTYPE_t *>malloc(N * sizeof(DTYPE_t))
    cdef np.ndarray[DTYPE_t, ndim=1] arr = np.PyArray_SimpleNewFromData(1, &N, np.NPY_INT32, data)
    PyArray_ENABLEFLAGS(arr, np.NPY_ARRAY_OWNDATA)
                          ^
------------------------------------------------------------

/tmp/test.pyx:19:27: Cannot convert Python object to 'PyArrayObject *'

我的问题:
在这种情况下,这是正确的方法吗?如果是这样,我在做什么错?如果没有,如何在不使用C扩展模块的情况下强制NumPy拥有Cython的所有权?


阅读 214

收藏
2021-01-20

共1个答案

小编典典

接口定义中只有一些小错误。以下为我工作:

from libc.stdlib cimport malloc
import numpy as np
cimport numpy as np

np.import_array()

ctypedef np.int32_t DTYPE_t

cdef extern from "numpy/arrayobject.h":
    void PyArray_ENABLEFLAGS(np.ndarray arr, int flags)

cdef data_to_numpy_array_with_spec(void * ptr, np.npy_intp N, int t):
    cdef np.ndarray[DTYPE_t, ndim=1] arr = np.PyArray_SimpleNewFromData(1, &N, t, ptr)
    PyArray_ENABLEFLAGS(arr, np.NPY_OWNDATA)
    return arr

def test():
    N = 1000

    cdef DTYPE_t *data = <DTYPE_t *>malloc(N * sizeof(DTYPE_t))
    arr = data_to_numpy_array_with_spec(data, N, np.NPY_INT32)
    return arr

这是我的setup.py文件:

from distutils.core import setup, Extension
from Cython.Distutils import build_ext
ext_modules = [Extension("_owndata", ["owndata.pyx"])]
setup(cmdclass={'build_ext': build_ext}, ext_modules=ext_modules)

用构建python setup.py build_ext --inplace。然后验证数据是实际拥有的:

import _owndata
arr = _owndata.test()
print arr.flags

除其他外,您应该看到OWNDATA : True

而且 ,这绝对是处理这个正确的方式,因为numpy.pxd究竟到所有其他功能导出到用Cython同样的事情。

2021-01-20