我正在尝试使用numpy.lib.stride_tricks.as_strided遍历数组的非重叠块,但是在查找参数文档时遇到了麻烦,因此我只能获得重叠块。
例如,我有一个4x5数组,我想从中获取4个2x2块。我可以排除右侧和底部边缘的多余单元格。
到目前为止,我的代码是:
import sys import numpy as np a = np.array([ [1,2,3,4,5], [6,7,8,9,10], [11,12,13,14,15], [16,17,18,19,20], ]) sz = a.itemsize h,w = a.shape bh,bw = 2,2 shape = (h/bh, w/bw, bh, bw) strides = (w*sz, sz, w*sz, sz) blocks = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides) print blocks[0][0] assert blocks[0][0].tolist() == [[1, 2], [6,7]] print blocks[0][1] assert blocks[0][1].tolist() == [[3,4], [8,9]] print blocks[1][0] assert blocks[1][0].tolist() == [[11, 12], [16, 17]]
结果得到的blocks数组的形状似乎正确,但是最后两个断言失败了,大概是因为我的shape或stride参数不正确。我应该设置这些值的什么值以获得不重叠的块?
import numpy as np n=4 m=5 a = np.arange(1,n*m+1).reshape(n,m) print(a) # [[ 1 2 3 4 5] # [ 6 7 8 9 10] # [11 12 13 14 15] # [16 17 18 19 20]] sz = a.itemsize h,w = a.shape bh,bw = 2,2 shape = (h/bh, w/bw, bh, bw) print(shape) # (2, 2, 2, 2) strides = sz*np.array([w*bh,bw,w,1]) print(strides) # [40 8 20 4] blocks=np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides) print(blocks) # [[[[ 1 2] # [ 6 7]] # [[ 3 4] # [ 8 9]]] # [[[11 12] # [16 17]] # [[13 14] # [18 19]]]]
从1in a(即blocks[0,0,0,0])开始,到2(即blocks[0,0,0,1])的距离只有一项。由于(在我的机器上)a.itemsize为4个字节,因此跨度为1 * 4 =4。这为我们提供了中的最后一个值strides = (10,2,5,1)*a.itemsize = (40,8,20,4)。
1
a
blocks[0,0,0,0]
2
blocks[0,0,0,1]
a.itemsize
strides = (10,2,5,1)*a.itemsize = (40,8,20,4)
从另一处开始1,要到达6(即blocks[0,0,1,0])距离5(即w)个项目,因此跨步为5 * 4 =20。这占中倒数第二个值strides。
6
blocks[0,0,1,0]
w
strides
从另一处开始1,要到达3(即blocks[0,1,0,0]),便是2(即bw)项,因此步幅为2 * 4 =8。这占的第二个值strides。
3
blocks[0,1,0,0]
bw
最后,从出发1,到达11(即blocks[1,0,0,0])是10(即w*bh)项,因此跨度为10 * 4 = 40 strides = (40,8,20,4)。
11
blocks[1,0,0,0]
w*bh
strides = (40,8,20,4)