我构建了一个与 Kruskal 的 MST 算法一起使用的不相交集数据结构。我需要加载并合并一个具有 200k 个互连节点的图,我认为我的数据结构实现是一个瓶颈。
您对如何提高性能有什么建议吗?我认为我的 find 方法可能有问题。
class partition(object):
def __init__(self, element=None):
self.size = 0
if element == None:
self.contents = set()
self.representative = None
else:
self.contents = {element}
self.representative = element
self.size = 1
def find(self, element):
return element in self.contents
def add(self, partition):
self.contents = self.contents.union(partition)
self.size = len(self.contents)
def show(self):
return self.contents
def __repr__(self):
return str(self.contents)
class disjoint_set(object):
def __init__(self):
self.partitions_count = 0
self.forest = {}
def make_set(self, element):
if self.find(element) == False:
new_partition = partition(element)
self.forest[new_partition.representative] = new_partition
self.partitions_count += 1
def union(self, x, y):
if x != y:
if self.forest[x].size < self.forest[y].size:
self.forest[y].add(self.forest[x].show())
self.delete(x)
else:
self.forest[x].add(self.forest[y].show())
self.delete(y)
def find(self, element):
for partition in self.forest.keys():
if self.forest[partition].find(element):
return self.forest[partition].representative
return False
def delete(self, partition):
del self.forest[partition]
self.partitions_count -= 1
if __name__ == '__main__':
t = disjoint_set()
t.make_set(1)
t.make_set(2)
t.make_set(3)
print("Create 3 singleton partitions:")
print(t.partitions_count)
print(t.forest)
print("Union two into a single partition:")
t.union(1,2)
print(t.forest)
print(t.partitions_count)
编辑:
在阅读了评论并进行了额外的研究后,我意识到我的原始算法设计得多么糟糕。我从头开始,把它放在一起。我将所有分区放入一个哈希表中,并在 find() 中使用了路径压缩。这看起来如何?我应该解决什么明显的问题?
class disjoint_set(object):
def __init__(self):
self.partitions_count = 0
self.size = {}
self.parent = {}
def make_set(self, element):
if self.find(element) == False:
self.parent[element] = element
self.size[element] = 1
self.partitions_count += 1
def union(self, x, y):
xParent = self.find(x)
yParent = self.find(y)
if xParent != yParent:
if self.size[xParent] < self.size[yParent]:
self.parent[xParent] = yParent
self.size[yParent] += self.size[xParent]
self.partitions_count -= 1
else:
self.parent[yParent] = xParent
self.size[xParent] += self.size[yParent]
self.partitions_count -= 1
def find(self, element):
if element in self.parent:
if element == self.parent[element]:
return element
root = self.parent[element]
while self.parent[root] != root:
root = self.find(self.parent[root])
self.parent[element] = root
return root
return False
if __name__ == '__main__':
t = disjoint_set()
t.make_set(1)
t.make_set(2)
t.make_set(3)
t.make_set(4)
t.make_set(5)
print("Create 5 singleton partitions")
print(t.partitions_count)
print("Union two singletons into a single partition")
t.union(1,2)
print("Union three singletones into a single partition")
t.union(3,4)
t.union(5,4)
print("Union a single partition")
t.union(2,4)
print("Parent List: %s" % t.parent)
print("Partition Count: %s" % t.partitions_count)
print("Parent of element 2: %s" % t.find(2))
print("Parent List: %s" % t.parent)