小编典典

python numpy.where()如何工作?

python

我正在玩耍numpy并浏览文档,并且遇到了一些魔术。即我正在谈论numpy.where()

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))

它们如何在内部实现您能够将类似的东西传递x > 5给方法的功能?我想这与它有关,__gt__但是我正在寻找详细的解释。


阅读 232

收藏
2020-12-20

共1个答案

小编典典

他们如何在内部实现将x> 5之类的内容传递给方法的能力?

简短的答案是他们没有。

对numpy数组进行任何形式的逻辑运算都会返回一个布尔数组。(即__gt__,,__lt__等等都返回给定条件为true的布尔数组)。

例如

x = np.arange(9).reshape(3,3)
print x > 5

产量:

array([[False, False, False],
       [False, False, False],
       [ True,  True,  True]], dtype=bool)

这就是为什么类似的东西if x > 5:如果x是一个numpy数组会引发ValueError的原因。它是True /
False值的数组,而不是单个值。

此外,numpy数组可以由布尔数组索引。例如,在这种情况下,x[x>5]yields [6 7 8]

老实说,您实际需要的很少,numpy.where但它只返回布尔数组为的索引True。通常,您可以使用简单的布尔索引执行所需的操作。

2020-12-20