
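I built a disjoint-set data structure for use with Kruskal's MST algorithm. I need to load and union a graph with 200k interconnected nodes, and I think my data structure implementation is the bottleneck.

Do you have any suggestions for improving performance? I think my find method may be the problem.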

class partition(object):
    def __init__(self, element=None):
        self.size = 0
        if element == None:
            self.contents = set()
            self.representative = None
        else:
            self.contents = {element}
            self.representative = element
            self.size = 1

    def find(self, element):
        return element in self.contents

    def add(self, partition):
        self.contents = self.contents.union(partition)
        self.size = len(self.contents)

    def show(self):
        return self.contents

    def __repr__(self):
        return str(self.contents)

class disjoint_set(object):
    def __init__(self):
        self.partitions_count = 0
        self.forest = {}

    def make_set(self, element):
        if self.find(element) == False:
            new_partition = partition(element)
            self.forest[new_partition.representative] = new_partition
            self.partitions_count += 1

    def union(self, x, y):
        if x != y:
            if self.forest[x].size < self.forest[y].size:
                self.forest[y].add(self.forest[x].show())
                self.delete(x)
            else:
                self.forest[x].add(self.forest[y].show())
                self.delete(y)

    def find(self, element):
        for partition in self.forest.keys():
            if self.forest[partition].find(element):
                return self.forest[partition].representative
        return False

    def delete(self, partition):
        del self.forest[partition]
        self.partitions_count -= 1

if __name__ == '__main__':
    t = disjoint_set()
    t.make_set(1)
    t.make_set(2)
    t.make_set(3)
    print("Create 3 singleton partitions:")
    print(t.partitions_count)
    print(t.forest)
    print("Union two into a single partition:")
    t.union(1,2)
    print(t.forest)
    print(t.partitions_count)
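Why find() matters so much here: Kruskal's algorithm calls find for both endpoints of every edge, and the find() above scans every partition in the forest, so each call costs time proportional to the number of partitions. Even make_set() calls find(), so creating 200k singletons already takes on the order of n²/2 partition checks. Below is a minimal sketch of Kruskal's algorithm over a dictionary-based union-find (union by size plus path compression). The function name `kruskal_mst` and the `(weight, u, v)` edge format are assumptions for illustration, not the asker's code.

```python
def kruskal_mst(nodes, edges):
    """Return the MST edges of an undirected graph.

    nodes: iterable of hashable node ids
    edges: iterable of (weight, u, v) tuples (assumed input format)
    """
    parent = {n: n for n in nodes}  # every node starts as its own root
    size = {n: 1 for n in nodes}    # size of the tree rooted at each root

    def find(x):
        # First pass: walk up to the root.
        root = x
        while parent[root] != root:
            root = parent[root]
        # Second pass: point every node on the path directly at the root.
        while x != root:
            nxt = parent[x]
            parent[x] = root
            x = nxt
        return root

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra == rb:
            return False  # already connected: this edge would form a cycle
        if size[ra] < size[rb]:
            ra, rb = rb, ra  # attach the smaller tree under the larger one
        parent[rb] = ra
        size[ra] += size[rb]
        return True

    mst = []
    # Sort by weight only, so node ids never need to be comparable.
    for w, u, v in sorted(edges, key=lambda e: e[0]):
        if union(u, v):
            mst.append((w, u, v))
    return mst


edges = [(1, 'a', 'b'), (4, 'a', 'c'), (2, 'b', 'c'), (3, 'c', 'd')]
print(kruskal_mst('abcd', edges))  # [(1, 'a', 'b'), (2, 'b', 'c'), (3, 'c', 'd')]
```

Each find here follows parent pointers instead of scanning partitions, so the per-edge cost no longer grows with the number of partitions.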

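Edit:

After reading the comments and doing some more research, I realized how poorly designed my original algorithm was. I started over from scratch and put the following together. I store all the partitions in hash tables and use path compression in find(). How does this look? Are there any obvious problems I should fix?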

class disjoint_set(object):
    def __init__(self):
        self.partitions_count = 0
        self.size = {}
        self.parent = {}

    def make_set(self, element):
        if self.find(element) == False:
            self.parent[element] = element
            self.size[element] = 1
            self.partitions_count += 1

    def union(self, x, y):
        xParent = self.find(x)
        yParent = self.find(y)
        if xParent != yParent:
            if self.size[xParent] < self.size[yParent]:
                self.parent[xParent] = yParent
                self.size[yParent] += self.size[xParent]
                self.partitions_count -= 1
            else:
                self.parent[yParent] = xParent
                self.size[xParent] += self.size[yParent]
                self.partitions_count -= 1

    def find(self, element):
        if element in self.parent:
            if element == self.parent[element]:
                return element
            root = self.parent[element]
            while self.parent[root] != root:
                root = self.find(self.parent[root])
            self.parent[element] = root
            return root
        return False

if __name__ == '__main__':
    t = disjoint_set()
    t.make_set(1)
    t.make_set(2)
    t.make_set(3)
    t.make_set(4)
    t.make_set(5)
    print("Create 5 singleton partitions")
    print(t.partitions_count)
    print("Union two singletons into a single partition")
    t.union(1,2)
    print("Union three singletons into a single partition")
    t.union(3,4)
    t.union(5,4)
    print("Union a single partition")
    t.union(2,4)
    print("Parent List: %s" % t.parent)
    print("Partition Count: %s" % t.partitions_count)
    print("Parent of element 2: %s" % t.find(2))
    print("Parent List: %s" % t.parent)
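One obvious problem in the edited version: find() returns False for unknown elements, and make_set() compares that result with `== False`. Because `0 == False` is true in Python, calling make_set(0) again (or after 0 has been merged into another partition) would reset 0 to a fresh singleton and overcount partitions. Below is a minimal sketch that tests membership directly and compresses paths iteratively instead of recursing inside a loop. The class name `DisjointSet` is hypothetical, and it assumes elements are hashable.

```python
class DisjointSet:
    """Union-find over hashable elements (union by size + path compression)."""

    def __init__(self):
        self.parent = {}
        self.size = {}
        self.partitions_count = 0

    def make_set(self, element):
        # Test membership directly instead of comparing find()'s result to
        # False, which misfires for elements such as 0 (0 == False is True).
        if element not in self.parent:
            self.parent[element] = element
            self.size[element] = 1
            self.partitions_count += 1

    def find(self, element):
        # First pass: locate the root (raises KeyError for unknown elements).
        root = element
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: point every node on the path directly at the root.
        while element != root:
            nxt = self.parent[element]
            self.parent[element] = root
            element = nxt
        return root

    def union(self, x, y):
        x_root, y_root = self.find(x), self.find(y)
        if x_root == y_root:
            return
        if self.size[x_root] < self.size[y_root]:
            x_root, y_root = y_root, x_root  # keep x_root as the larger tree
        self.parent[y_root] = x_root
        self.size[x_root] += self.size[y_root]
        self.partitions_count -= 1


ds = DisjointSet()
for e in range(5):  # includes 0, which the edited make_set() mishandles
    ds.make_set(e)
ds.union(0, 1)
ds.make_set(0)  # no-op: 0 already belongs to a partition
print(ds.partitions_count)       # 4
print(ds.find(1) == ds.find(0))  # True
```

Raising KeyError for an unknown element, rather than returning a sentinel, keeps a missing element from being confused with a legitimate one.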

1 Answer


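I suspect your find implementation is not as efficient as it should be: it scans every partition on each call.

The following changes may help.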

class disjoint_set(object):
    # Relies on the partition class from the question.
    def __init__(self):
        self.partitions_count = 0
        self.forest = {}
        self.parent = {}

    def make_set(self, element):
        # Check membership directly; find() would raise KeyError for a new element.
        if element not in self.parent:
            new_partition = partition(element)
            self.parent[element] = element
            self.forest[new_partition.representative] = new_partition
            self.partitions_count += 1

    def union(self, x, y):
        # Work on the representatives, not on arbitrary members.
        x, y = self.find(x), self.find(y)
        if x != y:
            if self.forest[x].size < self.forest[y].size:
                self.forest[y].add(self.forest[x].show())
                # Update parent pointer
                self.parent[self.forest[x].representative] = self.forest[y].representative
                self.delete(x)
            else:
                self.forest[x].add(self.forest[y].show())
                # Update parent pointer
                self.parent[self.forest[y].representative] = self.forest[x].representative
                self.delete(y)

    def find(self, element):
        # Follow parent pointers up to the representative.
        if self.parent[element] == element:
            return element
        return self.find(self.parent[element])

    def delete(self, representative):
        del self.forest[representative]
        self.partitions_count -= 1

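The code can be optimized further with path compression. Combined with union by size, that makes disjoint_set.find run in amortized O(α(n)) time, where α is the inverse Ackermann function: effectively constant, though not strictly O(1). Even without path compression, union by size keeps every tree O(log n) deep, which should still be fine for large inputs.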
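To make the path-compression point concrete, here is a small sketch that builds a deliberately degenerate chain by hand (bypassing union by size purely so the effect is visible) and shows the parent pointers after a single compressed find. The standalone helper `find(parent, x)` is hypothetical and operates on a plain dict of parent pointers.

```python
def find(parent, x):
    """Return the root of x, pointing every node on the walked path at it."""
    root = x
    while parent[root] != root:  # first pass: find the root
        root = parent[root]
    while x != root:             # second pass: compress the path
        nxt = parent[x]
        parent[x] = root
        x = nxt
    return root


# A degenerate chain 4 -> 3 -> 2 -> 1 -> 0, with 0 as the root.
parent = {0: 0, 1: 0, 2: 1, 3: 2, 4: 3}
print(find(parent, 4))  # 0
print(parent)           # {0: 0, 1: 0, 2: 0, 3: 0, 4: 0}
```

After one find, every node on the path points straight at the root, so later finds on those nodes take a single step.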

Another bottleneck may be your union function, in particular the implementation of add:

def add(self, partition):
    self.contents = self.contents.union(partition)

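Try using set's update method instead, which performs the union in place. union() builds a brand-new set on every merge, which I think causes a lot of memory overhead with a large number of nodes. Something like: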

self.contents.update(partition)

There is a good discussion of set union vs. update here.
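A quick sketch of the difference: union() copies every element of both sets into a new set, while update() (equivalently, the `|=` operator) inserts only the other set's elements into the existing set and returns None. Because the original union() already adds the smaller partition into the larger, switching to update() means each element is copied only O(log n) times in total, since its partition at least doubles every time it moves.

```python
a = {1, 2, 3}
b = {3, 4}

c = a.union(b)          # builds and returns a brand-new set; a is unchanged
print(c is a)           # False
print(sorted(a))        # [1, 2, 3]

returned = a.update(b)  # adds b's elements to a in place
print(returned)         # None: update() mutates a and returns nothing
print(sorted(a))        # [1, 2, 3, 4]
```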

Hope this helps!

answered 2017-04-27T00:06:21.410