0

我想在非方形 numpy 数组上使用 python 中的匈牙利分配算法。

我的输入矩阵X如下所示:

X = np.array([[0.26, 0.64, 0.16, 0.46, 0.5 , 0.63, 0.29],
              [0.49, 0.12, 0.61, 0.28, 0.74, 0.54, 0.25],
              [0.22, 0.44, 0.25, 0.76, 0.28, 0.49, 0.89],
              [0.56, 0.13, 0.45, 0.6 , 0.53, 0.56, 0.05],
              [0.66, 0.24, 0.61, 0.21, 0.47, 0.31, 0.35],
              [0.4 , 0.85, 0.45, 0.14, 0.26, 0.29, 0.24]])

X所需的结果是排序的矩阵,例如X_desired_output

X_desired_output = np.array([[0.63, 0.5 , 0.29, 0.46, 0.26, 0.64, 0.16], 
                             [0.54, 0.74, 0.25, 0.28, 0.49, 0.12, 0.61], 
                             [[0.49, 0.28, 0.89, 0.76, 0.22, 0.44, 0.25], 
                             [[0.56, 0.53, 0.05, 0.6 , 0.56, 0.13, 0.45], 
                             [[0.31, 0.47, 0.35, 0.21, 0.66, 0.24, 0.61], 
                             [[0.29, 0.26, 0.24, 0.14, 0.4 , 0.85, 0.45]])

在这里,我想最大化成本而不是最小化,因此算法的输入在理论上要么是,要么是1-X简单X的。

我发现https://software.clapper.org/munkres/导致:

from munkres import Munkres

m = Munkres()
indices = m.compute(-X)

indices
[(0, 5), (1, 4), (2, 6), (3, 3), (4, 0), (5, 1)]

# getting the indices in list format
ii = [i for (i,j) in indices]
jj = [j for (i,j) in indices]

我怎样才能使用这些来排序Xjj仅包含 6 个元素,而不是原始的 7 列X

我正在寻找实际排序的矩阵。

4

1 回答 1

0

在花了几个小时研究它之后,我找到了一个解决方案。问题是由于X.shape[1] > X.shape[0]某些列根本没有分配,这导致了问题。

该文件指出

“Munkres 算法假设成本矩阵是方形的。但是,如果您首先用 0 值填充它以使其成为方形,则可以使用矩形矩阵。该模块会自动填充矩形成本矩阵以使其成为方形。”

from munkres import Munkres

m = Munkres()
indices = m.compute(-X)

indices
[(0, 5), (1, 4), (2, 6), (3, 3), (4, 0), (5, 1)]

# getting the indices in list format
ii = [i for (i,j) in indices]
jj = [j for (i,j) in indices]

# re-order matrix
X_=X[:,jj]  # re-order columns
X_=X_[ii,:] # re-order rows

# HERE IS THE TRICK: since the X is not diagonal, some columns are not assigned to the rows !
not_assigned_columns = X[:, [not_assigned for not_assigned in np.arange(X.shape[1]).tolist() if not_assigned not in jj]].reshape(-1,1)

X_desired = np.concatenate((X_, not_assigned_columns), axis=1)

print(X_desired)

array([[0.63, 0.5 , 0.29, 0.46, 0.26, 0.64, 0.16],
       [0.54, 0.74, 0.25, 0.28, 0.49, 0.12, 0.61],
       [0.49, 0.28, 0.89, 0.76, 0.22, 0.44, 0.25],
       [0.56, 0.53, 0.05, 0.6 , 0.56, 0.13, 0.45],
       [0.31, 0.47, 0.35, 0.21, 0.66, 0.24, 0.61],
       [0.29, 0.26, 0.24, 0.14, 0.4 , 0.85, 0.45]])
于 2021-11-16T15:38:54.690 回答