12

假设我有一个 CSR 格式的矩阵,将一行(或多行)设置为零的最有效方法是什么?

以下代码运行非常缓慢:

A = A.tolil()
A[indices, :] = 0
A = A.tocsr()

我不得不转换为,scipy.sparse.lil_matrix因为 CSR 格式似乎既不支持花哨的索引也不支持为切片设置值。

4

4 回答 4

10

我猜 scipy 只是没有实现它,但 CSR 格式会很好地支持这一点,请阅读关于“稀疏矩阵”的维基百科文章,了解什么indptr等是:

# A.indptr is an array, one for each row (+1 for the nnz):

def csr_row_set_nz_to_val(csr, row, value=0):
    """Set all nonzero elements (elements currently in the sparsity pattern)
    to the given value. Useful to set to 0 mostly.
    """
    if not isinstance(csr, scipy.sparse.csr_matrix):
        raise ValueError('Matrix given must be of CSR format.')
    csr.data[csr.indptr[row]:csr.indptr[row+1]] = value

# Now you can just do:
for row in indices:
    csr_row_set_nz_to_val(A, row, 0)

# And to remove zeros from the sparsity pattern:
A.eliminate_zeros()

当然,这会eliminate_zeros从稀疏模式中删除从另一个地方设置的 0。如果你想这样做(此时)取决于你真正在做什么,即。延迟消除可能是有意义的,直到所有其他可能添加新零的计算也完成,或者在某些情况下,您可能有 0 个值,您想稍后再次更改,因此消除它们会非常糟糕!

原则上你当然可以短路eliminate_zerosand prune,但这应该很麻烦,而且可能更慢(因为你不会在 C 中这样做)。


有关 eliminiate_zeros (和修剪)的详细信息

稀疏矩阵通常不保存零元素,而只是存储非零元素所在的位置(大致并使用各种方法)。eliminate_zeros从稀疏模式中删除矩阵中的所有零(即,没有为该位置存储任何值,而之前存储了一个值,但它是 0)。如果您想稍后将 0 更改为不同的值,则消除是不好的,否则会节省空间。

Prune 只会在存储的数据数组超过必要时收缩它们。请注意,虽然我第一次A.prune()在那里,A.eliminiate_zeros()已经包括修剪。

于 2012-08-26T12:54:45.113 回答
1

您可以使用矩阵点积来实现归零。由于我们将使用的矩阵非常稀疏(我们要清零的行/列的对角线为零),乘法应该是有效的。

您将需要以下功能之一:

import scipy.sparse

def zero_rows(M, rows):
    diag = scipy.sparse.eye(M.shape[0]).tolil()
    for r in rows:
        diag[r, r] = 0
    return diag.dot(M)

def zero_columns(M, columns):
    diag = scipy.sparse.eye(M.shape[1]).tolil()
    for c in columns:
        diag[c, c] = 0
    return M.dot(diag)

使用示例:

>>> A = scipy.sparse.csr_matrix([[1,0,3,4], [5,6,0,8], [9,10,11,0]])
>>> A
<3x4 sparse matrix of type '<class 'numpy.int64'>'
        with 9 stored elements in Compressed Sparse Row format>
>>> A.toarray()
array([[ 1,  0,  3,  4],
       [ 5,  6,  0,  8],
       [ 9, 10, 11,  0]], dtype=int64)

>>> B = zero_rows(A, [1])
>>> B
<3x4 sparse matrix of type '<class 'numpy.float64'>'
        with 6 stored elements in Compressed Sparse Row format>
>>> B.toarray()
array([[  1.,   0.,   3.,   4.],
       [  0.,   0.,   0.,   0.],
       [  9.,  10.,  11.,   0.]])

>>> C = zero_columns(A, [1, 3])
>>> C
<3x4 sparse matrix of type '<class 'numpy.float64'>'
        with 5 stored elements in Compressed Sparse Row format>
>>> C.toarray()
array([[  1.,   0.,   3.,   0.],
       [  5.,   0.,   0.,   0.],
       [  9.,   0.,  11.,   0.]])
于 2017-03-30T09:59:57.070 回答
0

更新到最新版本的 scipy。它支持精美的索引。

于 2014-02-18T17:46:00.233 回答
0

我想完成@seberg 给出的答案。如果要将 nnz 值设置为零,则应修改 CSR 矩阵的结构,而不仅仅是修改.data属性。

此代码的当前行为是,

>>> import scipy.sparse
>>> import numpy as np
>>> A = scipy.sparse.csr_matrix([[0,1,0], [2,0,3], [0,0,0], [4,0,0]])
>>> A.toarray()
array([[0, 1, 0],
       [2, 0, 3],
       [0, 0, 0],
       [4, 0, 0]], dtype=int64)
>>> csr_row_set_nz_to_val(A, 1)
>>> A.toarray()
array([[0, 1, 0],
       [0, 0, 0],
       [0, 0, 0],
       [4, 0, 0]], dtype=int64)
>>> A.data
array([1, 0, 0, 4], dtype=int64)
>>> A.indices
array([1, 0, 2, 0], dtype=int32)
>>> A.indptr
array([0, 1, 3, 3, 4], dtype=int32)

因为我们正在处理稀疏矩阵,所以我们不希望A.data数组中出现零。我认为应该修改csr_row_set_nz_to_val如下

def csr_row_set_nz_to_val(csr, row, value=0):
    """Set all nonzero elements of a CSR matrix M (elements currently in the sparsity pattern)
    to the given value. Useful to set to 0 mostly.
    """
    if not isinstance(csr, scipy.sparse.csr_matrix):
        raise ValueError("Matrix given must be of CSR format.")
    if value == 0:
        csr.data = np.delete(csr.data, range(csr.indptr[row], csr.indptr[row+1])) # drop nnz values
        csr.indices = np.delete(csr.indices, range(csr.indptr[row], csr.indptr[row+1])) # drop nnz column indices
        csr.indptr[(row+1):] = csr.indptr[(row+1):] - (csr.indptr[row+1] - csr.indptr[row])
    else:
        csr.data[csr.indptr[row]:csr.indptr[row+1]] = value # replace nnz values by another nnz value

最后,我们会得到

>>> A = scipy.sparse.csr_matrix([[0,1,0], [2,0,3], [0,0,0], [4,0,0]])
>>> csr_row_set_nz_to_val(A, 1)
>>> A.toarray()
array([[0, 1, 0],
       [0, 0, 0],
       [0, 0, 0],
       [4, 0, 0]], dtype=int64)
>>> A.data
array([1, 4], dtype=int64)
>>> A.indices
array([1, 0], dtype=int32)
>>> A.indptr
array([0, 1, 1, 1, 2], dtype=int32)
于 2020-06-18T15:48:13.757 回答