Introduction

The Union Find (disjoint-set) data structure stores a collection of disjoint (non-overlapping) sets and can be used to model connected components in undirected graphs. It can be used to:

  • determine whether two vertices belong to the same component
  • detect cycles in an undirected graph
  • find a minimum spanning tree of a weighted, undirected graph (Kruskal's algorithm)

Union Find Implementation

Optimized Union Find (Disjoint Set) Python implementation with path compression (path halving) and union by size; the `rank` list actually stores the size of each set.

class UnionFind:
    """Disjoint-set forest over the integers ``0 .. size - 1``.

    Uses path halving in :meth:`find` and union by size in :meth:`union`,
    which together give amortized O(alpha(N)) time per operation, where
    alpha is the (extremely slow-growing) inverse Ackermann function.

    Attributes:
        root: ``root[i]`` is the parent of element ``i``; ``i`` is the
            representative of its set exactly when ``root[i] == i``.
        rank: for a representative ``r``, ``rank[r]`` is the number of
            elements in its set. Despite the name this is a *size*, not a
            rank/height bound (it grows by the other set's size on union).
            Entries for non-representatives are left stale.
    """

    def __init__(self, size: int) -> None:
        """Create ``size`` singleton sets ``{0}, {1}, ..., {size - 1}``.

        T: O(N)
        S: O(N)
        """
        self.root = list(range(size))
        self.rank = [1] * size

    def find(self, x: int) -> int:
        """Return the representative of the set containing ``x``.

        Shortens the path on the way up using path halving (every visited
        node is re-pointed at its grandparent).

        Raises:
            IndexError: if ``x`` is not in ``0 .. size - 1``. Negative values
                are rejected explicitly because plain list indexing would
                otherwise alias them to elements counted from the end.

        T: O(alpha(N)) amortized
        """
        root = self.root
        if not 0 <= x < len(root):
            raise IndexError(
                f"element {x} out of range for UnionFind of size {len(root)}"
            )
        while x != root[x]:
            # Path halving: link x to its grandparent, then jump there.
            root[x] = root[root[x]]
            x = root[x]
        return x

    def union(self, x: int, y: int) -> int:
        """Merge the sets containing ``x`` and ``y``.

        On a size tie, the representative of ``x``'s set becomes the root.

        Returns:
            1 if two distinct sets were merged, 0 if ``x`` and ``y`` were
            already in the same set.

        T: O(alpha(N)) amortized
        """
        a = self.find(x)
        b = self.find(y)

        if a == b:
            return 0

        # Union by size: hang the smaller tree under the larger one so tree
        # height stays O(log N). Swapping only on strict "<" keeps x's
        # representative as the root when sizes are equal.
        if self.rank[a] < self.rank[b]:
            a, b = b, a
        self.root[b] = a
        self.rank[a] += self.rank[b]

        return 1

    def connected(self, x: int, y: int) -> bool:
        """Return True if ``x`` and ``y`` are in the same set.

        T: O(alpha(N)) amortized
        """
        return self.find(x) == self.find(y)

Example 1

uf = UnionFind(10)

# Edges merged in order. Inserting (9, 6) right after (9, 5) would join
# the two resulting components into a single set.
edges = [(1, 3), (2, 3), (3, 0), (0, 4), (4, 9), (9, 5), (7, 6), (6, 8)]
for u, v in edges:
    uf.union(u, v)

print(uf.root)
print(uf.rank)

# Expected output:
# [1, 1, 1, 1, 1, 1, 7, 7, 7, 1]
# [1, 7, 1, 1, 1, 1, 1, 3, 1, 1]