小编典典

防止numpy创建多维数组

python

创建数组时,NumPy非常有用。如果for的第一个参数numpy.array具有__getitem__and__len__方法,则根据它们可能是有效序列使用它们。

不幸的是,我想创建一个包含dtype=object没有NumPy是“有用的”的数组。

分解为一个最小的示例,该类将如下所示:

import numpy as np

class Test(object):
    def __init__(self, iterable):
        self.data = iterable

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.data)

如果“可迭代对象”的长度不同,那么一切都很好,而我得到的结果恰好是我想要的:

>>> np.array([Test([1,2,3]), Test([3,2])], dtype=object)
array([Test([1, 2, 3]), Test([3, 2])], dtype=object)

但是NumPy会创建一个多维数组,如果它们恰好具有相同的长度:

>>> np.array([Test([1,2,3]), Test([3,2,1])], dtype=object)
array([[1, 2, 3],
       [3, 2, 1]], dtype=object)

不幸的是,只有一个ndmin参数,所以我想知道是否有一种方法可以强制ndmaxNumPy或以某种方式阻止NumPy将自定义类解释为另一个维度(不删除__len____getitem__)?


阅读 214

收藏
2020-12-20

共1个答案

小编典典

解决方法当然是创建所需形状的数组,然后复制数据:

In [19]: lst = [Test([1, 2, 3]), Test([3, 2, 1])]

In [20]: arr = np.empty(len(lst), dtype=object)

In [21]: arr[:] = lst[:]

In [22]: arr
Out[22]: array([Test([1, 2, 3]), Test([3, 2, 1])], dtype=object)

请注意,无论如何,如果解释可迭代对象的numpy行为(您要使用的是吧?)与numpy版本相关,我不会感到惊讶。甚至可能是越野车。也许其中一些错误实际上是功能。无论如何,当更改Numpy版本时,我会小心避免损坏。

相反,复制到预先创建的数组中应该更可靠。

2020-12-20