**[Disjoint Set Data Structure (Union-Find)](https://ja.wikipedia.org/wiki/%E7%B4%A0%E9%9B%86%E5%90%88%E3%83%87%E3%83%BC%E3%82%BF%E6%A7%8B%E9%80%A0)** is a data structure that classifies elements into groups that do not intersect each other. The standard implementation is the Union-Find tree. I think everyone already knows that.
There are many implementation examples of the Union-Find tree on Qiita, but I have long wondered whether `find()` and the rest could be written more clearly and concisely. So I wrote a short implementation in Python using dict and frozenset. It is 10 lines including blank lines.
class EasyUnionFind:
    """Disjoint-set structure built from a dict of frozensets.

    Each element maps to the frozenset containing its whole group, and all
    members of one group share the same frozenset object.
    """

    def __init__(self, n):
        """Create the elements 0..n-1, each in its own singleton group."""
        self._groups = {x: frozenset([x]) for x in range(n)}

    def union(self, x, y):
        """Merge the groups containing x and y (KeyError if either is unknown)."""
        group_x, group_y = self._groups[x], self._groups[y]
        if group_x is group_y:
            # Already in one group: skip rebuilding the set and rewriting
            # every member's entry, which would be an O(group size) no-op.
            return
        group = group_x | group_y
        self._groups.update((c, group) for c in group)

    def groups(self):
        """Return the current partition as a frozenset of frozensets."""
        return frozenset(self._groups.values())
Let's compare it with the Union-Find tree implementation. The comparison target is the [implementation introduced](https://www.kumilog.net/entry/union-find) in the article with the most likes found by searching for "Union Find Python" on Qiita.
**The result is that this 10-line implementation is slower.** Also, the difference seems to grow as the number of elements increases. Sorry.
Element count | Union-Find tree implementation | This 10-line implementation | Ratio of elapsed time |
---|---|---|---|
1000 | 0.72 seconds | 1.17 seconds | 1.63 |
2000 | 1.46 seconds | 2.45 seconds | 1.68 |
4000 | 2.93 seconds | 5.14 seconds | 1.75 |
8000 | 6.01 seconds | 11.0 seconds | 1.83 |
However, even though it is slower, it takes less than twice as long as the Union-Find tree, so it may be useful in some cases.
code:
import random
import timeit
import sys
import platform
class EasyUnionFind:
    """
    Implementation using dict and frozenset.

    Each element maps to the frozenset containing its whole group, and all
    members of one group share the same frozenset object.
    """

    def __init__(self, n):
        """Create the elements 0..n-1, each in its own singleton group."""
        self._groups = {x: frozenset([x]) for x in range(n)}

    def union(self, x, y):
        """Merge the groups containing x and y (KeyError if either is unknown)."""
        group_x, group_y = self._groups[x], self._groups[y]
        if group_x is group_y:
            # Already in one group: skip rebuilding the set and rewriting
            # every member's entry, which would be an O(group size) no-op.
            return
        group = group_x | group_y
        self._groups.update((c, group) for c in group)

    def groups(self):
        """Return the current partition as a frozenset of frozensets."""
        return frozenset(self._groups.values())
class UnionFind:
    """
    Typical Union-Find tree implementation.

    Copied from the example implementation at
    https://www.kumilog.net/entry/union-find, with the member functions not
    needed here removed and groups() added.
    """

    def __init__(self, n=1):
        """Create n elements (0..n-1), each the root of its own tree."""
        self.par = list(range(n))   # par[i]: parent of i (i itself for a root)
        self.rank = [0] * n         # rank[i]: rank used to pick the new root
        self.size = [1] * n         # size[r]: element count, kept for roots
        self.n = n

    def find(self, x):
        """Return the root of x's tree, re-pointing the visited path at it."""
        # First pass: walk up to the root.
        root = x
        while self.par[root] != root:
            root = self.par[root]
        # Second pass: attach every node on the path directly to the root.
        node = x
        while node != root:
            parent = self.par[node]
            self.par[node] = root
            node = parent
        return root

    def union(self, x, y):
        """Join the trees containing x and y, keeping the higher-rank root."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return
        if self.rank[root_x] < self.rank[root_y]:
            root_x, root_y = root_y, root_x
        elif self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1
        self.par[root_y] = root_x
        self.size[root_x] += self.size[root_y]

    def groups(self):
        """Return the current partition as a frozenset of frozensets."""
        members_by_root = {}
        for element in range(self.n):
            root = self.find(element)
            if root not in members_by_root:
                members_by_root[root] = []
            members_by_root[root].append(element)
        return frozenset(map(frozenset, members_by_root.values()))
def test1(n=2000, trials=1000):
    """Check that both implementations produce the same grouping.

    For each trial, n // 2 random pairs drawn from 0..n-1 are unioned in
    both UnionFind and EasyUnionFind, and the resulting groups are compared.

    Args:
        n: Number of elements in each structure.
        trials: Number of random trials to run.

    Raises:
        AssertionError: If the two implementations disagree in any trial.
    """
    print("===== TEST1 =====")
    random.seed(20200228)
    elements = range(n)  # loop-invariant, so built once
    for _ in range(trials):
        pairs = [
            (random.choice(elements), random.choice(elements))
            for _ in range(n // 2)
        ]
        uf1 = UnionFind(n)
        uf2 = EasyUnionFind(n)
        for x, y in pairs:
            uf1.union(x, y)
            uf2.union(x, y)
        # Raise explicitly instead of using `assert`, which is stripped when
        # Python runs with -O and would make this test pass vacuously.
        if uf1.groups() != uf2.groups():
            raise AssertionError(
                "UnionFind and EasyUnionFind produced different groups"
            )
    print('ok')
    print()
def test2():
    """
    Print the time each implementation needs as the element count grows.

    For every element count, 1000 lists of n // 2 random pairs are generated
    once and then replayed against both implementations under timeit.
    """
    print("===== TEST2 =====")
    random.seed(20200228)

    def replay(uf_class, size, batches):
        # Build a fresh structure per batch and apply all of its unions.
        for batch in batches:
            forest = uf_class(size)
            for a, b in batch:
                forest.union(a, b)

    runs = 1
    contenders = (("UnionFind", UnionFind), ("EasyUnionFind", EasyUnionFind))
    for n in (1000, 2000, 4000, 8000):
        print(f"n={n}")
        population = range(n)
        batches = []
        for _ in range(1000):
            batches.append([
                (random.choice(population), random.choice(population))
                for _ in range(n // 2)
            ])
        for label, impl in contenders:
            elapsed = timeit.timeit(lambda: replay(impl, n, batches), number=runs)
            print(label, elapsed)
        print()
def main():
    """Print the Python version and platform, then run test1 and test2."""
    print(sys.version)
    # Platform line followed by one blank line.
    print(platform.platform(), end="\n\n")
    for run in (test1, test2):
        run()


if __name__ == "__main__":
    main()
Execution result:
3.7.6 (default, Dec 30 2019, 19:38:28)
[Clang 11.0.0 (clang-1100.0.33.16)]
Darwin-18.7.0-x86_64-i386-64bit
===== TEST1 =====
ok
===== TEST2 =====
n=1000
UnionFind 0.7220867589999997
EasyUnionFind 1.1789850389999987
n=2000
UnionFind 1.460918638999999
EasyUnionFind 2.4546459260000013
n=4000
UnionFind 2.925022847000001
EasyUnionFind 5.142797402000003
n=8000
UnionFind 6.01257184
EasyUnionFind 10.963117657000005
Recommended Posts