我正在寻找一种巧妙的方法来提取大小为 2x2 的对角线块,这些对角线块位于 (2N)x(2N) numpy 数组的主对角线上(也就是说,将有 N 个这样的块)。这概括了 numpy.diag,它返回沿主对角线的元素,人们可能会将其视为 1x1 块(当然 numpy 不会以这种方式表示它们)。
为了更广泛地表达这一点,人们可能希望从 (MN)x(MN) 数组中提取 N MxM 块。这似乎是 scipy.linalg.block_diag 的补充,在How can I transform blocks into a blockdiagonal matrix (NumPy)中巧妙地讨论过,将块从 block_diag 放置的位置拉出。
从该问题的解决方案中借用代码,这是如何设置的:
>>> a1 = np.array([[1,1,1],[1,1,1],[1,1,1]])
>>> a2 = np.array([[2,2,2],[2,2,2],[2,2,2]])
>>> a3 = np.array([[3,3,3],[3,3,3],[3,3,3]])
>>> import scipy.linalg
>>> scipy.linalg.block_diag(a1, a2, a3)
array([[1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 2, 2, 2, 0, 0, 0],
[0, 0, 0, 2, 2, 2, 0, 0, 0],
[0, 0, 0, 2, 2, 2, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 3, 3, 3],
[0, 0, 0, 0, 0, 0, 3, 3, 3],
[0, 0, 0, 0, 0, 0, 3, 3, 3]])
然后,我们希望有一个像
>>> A = scipy.linalg.block_diag(a1, a2, a3)
>>> extract_block_diag(A, M=3)
array([[[1, 1, 1],
[1, 1, 1],
[1, 1, 1]],
[[2, 2, 2],
[2, 2, 2],
[2, 2, 2]],
[[3, 3, 3],
[3, 3, 3],
[3, 3, 3]]])
为了继续与 numpy.diag 进行类比,人们可能还希望提取非对角块:在第 k 个块对角线上的 N - k 个。(顺便说一句,block_diag 的扩展允许将块放置在主对角线之外肯定会有用,但这不是这个问题的范围。)对于上面的数组,这可能会产生:
>>> extract_block_diag(A, M=3, k=1)
array([[[0, 0, 0],
[0, 0, 0],
[0, 0, 0]],
[[0, 0, 0],
[0, 0, 0],
[0, 0, 0]]])
我看到这个问题中涉及的 stride_tricks 的使用旨在产生类似于此功能的东西,但我知道跨步在字节级别上运行,这听起来不太合适。
作为动机,这源于我希望提取协方差矩阵的对角元素(即方差)的情况,其中元素本身不是标量而是 2x2 矩阵。
编辑:根据Chris 的建议,我做了以下尝试:
def extract_block_diag(A,M,k=0):
"""Extracts blocks of size M from the kth diagonal
of square matrix A, whose size must be a multiple of M."""
# Check that the matrix can be block divided
if A.shape[0] != A.shape[1] or A.shape[0] % M != 0:
raise StandardError('Matrix must be square and a multiple of block size')
# Assign indices for offset from main diagonal
if abs(k) > M - 1:
raise StandardError('kth diagonal does not exist in matrix')
elif k > 0:
ro = 0
co = abs(k)*M
elif k < 0:
ro = abs(k)*M
co = 0
else:
ro = 0
co = 0
blocks = np.array([A[i+ro:i+ro+M,i+co:i+co+M]
for i in range(0,len(A)-abs(k)*M,M)])
return blocks
其中将针对上述数据返回以下结果:
D = extract_block_diag(A,3)
[[[1 1 1]
[1 1 1]
[1 1 1]]
[[2 2 2]
[2 2 2]
[2 2 2]]
[[3 3 3]
[3 3 3]
[3 3 3]]]
D = extract_block_diag(A,3,-1)
[[[0 0 0]
[0 0 0]
[0 0 0]]
[[0 0 0]
[0 0 0]
[0 0 0]]]