由于您是在矩阵及其转置的乘积之后,因此 at 的值[m, n]
基本上将是列m
和n
原始矩阵中的点积。
我将使用以下矩阵作为玩具示例
a = np.array([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1],
[0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1]])
>>> np.dot(a.T, a)
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
[0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 2]])
它有形状(3, 12)
并且有 7 个非零条目。它与它的转置的乘积当然是有形状的(12, 12)
,并且有 16 个非零条目,其中 6 个在对角线上,所以它只需要存储 11 个元素。
您可以通过以下两种方式之一很好地了解输出矩阵的大小:
企业社会责任格式
如果您的原始矩阵有C
非零列,那么您的新矩阵将最多有C**2
非零条目,其中C
在对角线上,并确保不为零,而其余条目您只需要保留一半,所以最多是(C**2 + C) / 2
非零元素。当然,其中许多也将为零,所以这可能是一个严重的高估。
如果您的矩阵以csr
格式存储,则indices
相应scipy
对象的属性具有一个数组,其中包含所有非零元素的列索引,因此您可以轻松地将上述估计计算为:
>>> a_csr = scipy.sparse.csr_matrix(a)
>>> a_csr.indices
array([ 2, 11, 1, 7, 10, 4, 11])
>>> np.unique(a_csr.indices).shape[0]
6
所以有 6 列具有非零条目,因此估计最多有 36 个非零条目,比实际的 16 多得多。
CSC 格式
如果我们有行索引而不是非零元素的列索引,我们实际上可以做一个更好的估计。为了使两列的点积不为零,它们必须在同一行中有一个非零元素。如果R
给定行中有非零元素,它们将为R**2
乘积贡献非零元素。当你对所有行求和时,你一定会多次计算一些元素,所以这也是一个上限。
矩阵的非零元素的行索引位于indices
稀疏 csc 矩阵的属性中,因此可以按如下方式计算此估计:
>>> a_csc = scipy.sparse.csc_matrix(a)
>>> a_csc.indices
array([1, 0, 2, 1, 1, 0, 2])
>>> rows, where = np.unique(a_csc.indices, return_inverse=True)
>>> where = np.bincount(where)
>>> rows
array([0, 1, 2])
>>> where
array([2, 3, 2])
>>> np.sum(where**2)
17
这与真正的16非常接近!而且这个估计实际上与以下内容相同并非巧合:
>>> np.sum(np.dot(a.T,a),axis=None)
17
无论如何,下面的代码应该可以让您看到估计非常好:
def estimate(a) :
a_csc = scipy.sparse.csc_matrix(a)
_, where = np.unique(a_csc.indices, return_inverse=True)
where = np.bincount(where)
return np.sum(where**2)
def test(shape=(10,1000), count=100) :
a = np.zeros(np.prod(shape), dtype=int)
a[np.random.randint(np.prod(shape), size=count)] = 1
print 'a non-zero = {0}'.format(np.sum(a))
a = a.reshape(shape)
print 'a.T * a non-zero = {0}'.format(np.flatnonzero(np.dot(a.T,
a)).shape[0])
print 'csc estimate = {0}'.format(estimate(a))
>>> test(count=100)
a non-zero = 100
a.T * a non-zero = 1065
csc estimate = 1072
>>> test(count=200)
a non-zero = 199
a.T * a non-zero = 4056
csc estimate = 4079
>>> test(count=50)
a non-zero = 50
a.T * a non-zero = 293
csc estimate = 294