我正在玩numpy
和挖掘文档,我遇到了一些魔法。即我在说numpy.where()
:
>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))
他们如何在内部实现您能够将类似的东西传递x > 5
给方法?我想这与它有关,__gt__
但我正在寻找详细的解释。
我正在玩numpy
和挖掘文档,我遇到了一些魔法。即我在说numpy.where()
:
>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))
他们如何在内部实现您能够将类似的东西传递x > 5
给方法?我想这与它有关,__gt__
但我正在寻找详细的解释。
他们如何在内部实现您能够将 x > 5 之类的东西传递给方法?
简短的回答是他们没有。
numpy 数组上的任何类型的逻辑操作都会返回一个布尔数组。(即__gt__
,__lt__
等都返回给定条件为真的布尔数组)。
例如
x = np.arange(9).reshape(3,3)
print x > 5
产量:
array([[False, False, False],
[False, False, False],
[ True, True, True]], dtype=bool)
这就是为什么如果是一个 numpy 数组if x > 5:
会引发 ValueError之类的东西的原因。x
它是一组真/假值,而不是单个值。
此外,numpy 数组可以由布尔数组索引。例如,在这种情况下,x[x>5]
产量。[6 7 8]
老实说,您真正需要的情况很少见,numpy.where
但它只返回布尔数组所在的索引True
。通常你可以用简单的布尔索引来做你需要的事情。
旧答案 有点令人困惑。它为您提供了您的陈述正确的位置(所有这些位置)。
所以:
>>> a = np.arange(100)
>>> np.where(a > 30)
(array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
99]),)
>>> np.where(a == 90)
(array([90]),)
a = a*40
>>> np.where(a > 1000)
(array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
94, 95, 96, 97, 98, 99]),)
>>> a[25]
1000
>>> a[26]
1040
我将它用作 list.index() 的替代方法,但它也有许多其他用途。我从未将它与二维数组一起使用。
http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html
新答案 似乎这个人在问一些更根本的问题。
问题是您如何实现允许功能(例如在哪里)知道所请求的内容的东西。
首先请注意,调用任何比较运算符都会做一件有趣的事情。
a > 1000
array([False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True`, True, True, True, True, True, True, True, True, True], dtype=bool)`
这是通过重载“__gt__”方法来完成的。例如:
>>> class demo(object):
def __gt__(self, item):
print item
>>> a = demo()
>>> a > 4
4
如您所见,“a > 4”是有效代码。
您可以在此处获取所有重载函数的完整列表和文档:http: //docs.python.org/reference/datamodel.html
令人难以置信的是做到这一点是多么简单。python中的所有操作都是以这种方式完成的。说 a > b 等价于 a。gt(乙)!
np.where
返回一个长度等于调用它的 numpy ndarray 的维度的元组(换句话说ndim
),并且元组的每个项目都是一个 numpy ndarray,其中包含条件为 True 的初始 ndarray 中所有这些值的索引。(请不要将尺寸与形状混淆)
例如:
x=np.arange(9).reshape(3,3)
print(x)
array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
y = np.where(x>4)
print(y)
array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))
y 是长度为 2 的元组,因为x.ndim
是 2。元组中的第一项包含所有大于 4 的元素的行号,第二项包含所有大于 4 的项的列号。如您所见,[1,2,2 ,2] 对应于 5,6,7,8 的行号,[2,0,1,2] 对应于 5,6,7,8 的列号注意,ndarray 是沿第一维遍历的(按行)。
相似地,
x=np.arange(27).reshape(3,3,3)
np.where(x>4)
将返回一个长度为 3 的元组,因为 x 有 3 个维度。
但是等等,np.where 还有更多内容!
当两个额外的参数被添加到np.where
; 它将对上述元组获得的所有成对行列组合进行替换操作。
x=np.arange(9).reshape(3,3)
y = np.where(x>4, 1, 0)
print(y)
array([[0, 0, 0],
[0, 0, 1],
[1, 1, 1]])