1

我有一个由数组组成的列表(或数组)n。每个数组携带一个任意整数子集 from0n-1(数字在数组中不重复)。一个例子n=3是:

l = [np.array([0, 1]), np.array([0]), np.array([1, 2])]

我想从每个数组中选择一个数字作为其代表,这样没有两个数组具有相同的代表,并以与数组相同的顺序列出它们。换句话说,为数组选择的数字必须是唯一的,并且整个代表集将是数字的0排列n-1。对于上面的列表,它将是唯一的:

representatives = [1, 0, 2]

可以保证我们的名单中存在这样的代表名单,但我们如何找到他们。如果有多个可能的代表列表,则可以随机选择其中任何一个。

4

2 回答 2

2

这是你要找的吗?

def pick_one(a, index, buffer, visited):
    if index == len(a):
        return True
    for item in a[index]:
        if item not in visited:
            buffer.append(item)
            visited.add(item)
            if pick_one(a, index + 1, buffer, visited):
                return True
            buffer.pop()
            visited.remove(item)
    return False


a = [[0, 1], [0], [1, 2]]
buffer = []
pick_one(a, 0, buffer, set())
print(buffer)

输出:

[1, 0, 2]
于 2020-06-25T08:50:32.783 回答
1

您要求的是二分图的最大匹配,其左右集分别由您的数组及其唯一元素索引。

networkx模块知道如何找到这样的最大匹配:

import numpy as np
import networkx as nx
import operator as op

def make_example(n,density=0.1):
    rng = np.random.default_rng()
    M = np.unique(np.concatenate([rng.integers(0,n,(int(n*n*density),2)),
                                  np.stack([np.arange(n),rng.permutation(n)],
                                           axis=1)],axis=0),axis=0)
    return np.split(M[:,1],(M[:-1,0] != M[1:,0]).nonzero()[0])

def find_matching(M):
    G = nx.Graph()
    m = len(M)
    n = 1+max(map(max,M))
    G.add_nodes_from(range(n,m+n), biparite=0)
    G.add_nodes_from(range(n),biparite=1)
    G.add_edges_from((i,j) for i,r in enumerate(M,n) for j in r)
    return op.itemgetter(*range(n,m+n))(nx.bipartite.maximum_matching(G))

例子:

>>> M = make_example(10,0.4)
>>> M
[array([0, 4, 8]), array([9, 3, 5]), array([7, 1, 3, 4, 5, 7, 8]), array([9, 0, 4, 5]), array([9, 0, 1, 3, 5]), array([6, 0, 1, 2, 8]), array([9, 3, 5, 7]), array([8, 1, 2, 5]), array([6]), array([7, 0, 1, 4, 6])]
>>> find_matching(M)
(0, 9, 5, 4, 1, 2, 3, 8, 6, 7)

这可以在几秒钟内完成数千个元素:

>>> M = make_example(10000,0.01)
>>> t0,sol,t1 = [time.perf_counter(),find_matching(M),time.perf_counter()]
>>> print(t1-t0)
3.822795882006176
于 2020-06-25T22:47:31.983 回答