개념

**최소 신장 트리 (Minimum Spanning Tree, MST)**를 구하는 그리디 알고리즘

가중치가 작은 간선부터 선택하며, 사이클이 생기지 않도록 연결


시간복잡도

구현 방식시간복잡도
간선 정렬 + Union-FindO(E log E)
공간복잡도O(V + E)

특징

장점

  • 구현 간단: 간선 정렬 + Union-Find
  • 희소 그래프 효율적: 간선 적으면 빠름
  • 최적해 보장: 그리디로 MST 구함

단점

  • 간선 정렬 필요: O(E log E) 소요
  • 밀집 그래프: 간선 많으면 프림이 나음

MST란?

  • 신장 트리 (Spanning Tree): 그래프의 모든 정점을 포함하는 트리
  • 최소 신장 트리: 간선 가중치의 합이 최소인 신장 트리
  • 간선 개수: V-1개

구현 방법

기본 구현 (Union-Find)

class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n
 
    def find(self, x):
        """경로 압축"""
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]
 
    def union(self, x, y):
        """랭크 기반 합치기"""
        root_x = self.find(x)
        root_y = self.find(y)
 
        if root_x == root_y:
            return False  # 이미 같은 집합
 
        # 랭크가 작은 쪽을 큰 쪽에 합침
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1
 
        return True
 
def kruskal(n, edges):
    """크루스칼 알고리즘"""
    # 간선을 가중치 기준으로 정렬
    edges.sort(key=lambda x: x[2])
 
    uf = UnionFind(n)
    mst = []
    total_weight = 0
 
    for u, v, weight in edges:
        # 사이클이 생기지 않으면 추가
        if uf.union(u, v):
            mst.append((u, v, weight))
            total_weight += weight
 
            # V-1개 간선이면 종료
            if len(mst) == n - 1:
                break
 
    return mst, total_weight
 
# 테스트
edges = [
    (0, 1, 4),
    (0, 2, 3),
    (1, 2, 1),
    (1, 3, 2),
    (2, 3, 4)
]
 
mst, total_weight = kruskal(4, edges)
print(f"MST 간선: {mst}")
print(f"총 가중치: {total_weight}")  # 6
 
# 시간복잡도: O(E log E)
# 공간복잡도: O(V + E)

간단한 Union-Find 구현

def kruskal_simple(n, edges):
    """간단한 크루스칼"""
    edges.sort(key=lambda x: x[2])
 
    parent = list(range(n))
 
    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]
 
    def union(x, y):
        root_x = find(x)
        root_y = find(y)
 
        if root_x == root_y:
            return False
 
        parent[root_y] = root_x
        return True
 
    mst = []
    total_weight = 0
 
    for u, v, weight in edges:
        if union(u, v):
            mst.append((u, v, weight))
            total_weight += weight
 
    return mst, total_weight
 
# 테스트
edges = [
    (0, 1, 4), (0, 2, 3), (1, 2, 1),
    (1, 3, 2), (2, 3, 4)
]
 
mst, total_weight = kruskal_simple(4, edges)
print(f"총 가중치: {total_weight}")  # 6

동작 과정

그래프:
    0 ---4--- 1
    |         |
    3         2
    |         |
    2 ---1--- 3

간선: [(0,1,4), (0,2,3), (1,2,1), (1,3,2), (2,3,4)]
정렬: [(1,2,1), (1,3,2), (0,2,3), (0,1,4), (2,3,4)]

1단계: (1,2,1) - 가중치 1
  1과 2를 연결
  MST: [(1,2,1)]
  총 가중치: 1

2단계: (1,3,2) - 가중치 2
  1과 3을 연결
  MST: [(1,2,1), (1,3,2)]
  총 가중치: 3

3단계: (0,2,3) - 가중치 3
  0과 2를 연결
  MST: [(1,2,1), (1,3,2), (0,2,3)]
  총 가중치: 6

4단계: (0,1,4) - 가중치 4
  0과 1은 이미 연결됨 (사이클 발생) → 스킵

5단계: (2,3,4) - 가중치 4
  2와 3은 이미 연결됨 (사이클 발생) → 스킵

완료: MST 간선 3개 (V-1개)
총 가중치: 6

템플릿

def kruskal_template(n, edges):
    # 1. 간선 정렬 (가중치 기준)
    edges.sort(key=lambda x: x[2])
 
    # 2. Union-Find 초기화
    parent = list(range(n))
 
    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]
 
    def union(x, y):
        root_x = find(x)
        root_y = find(y)
 
        if root_x == root_y:
            return False
 
        parent[root_y] = root_x
        return True
 
    # 3. 간선 선택
    mst = []
    total_weight = 0
 
    for u, v, weight in edges:
        if union(u, v):
            mst.append((u, v, weight))
            total_weight += weight
 
            if len(mst) == n - 1:
                break
 
    return mst, total_weight

핵심 요약

개념설명
원리그리디 - 가장 작은 간선부터 선택
시간복잡도O(E log E)
공간복잡도O(V + E)
활용최소 신장 트리 (MST)
핵심 자료구조Union-Find

프림과 비교

특성크루스칼프림
시간복잡도O(E log E)O(E log V)
접근간선 중심정점 중심
자료구조Union-Find우선순위 큐
적합한 그래프희소 그래프밀집 그래프
구현간단다익스트라와 유사

주의사항

1. 연결 그래프 확인

def kruskal_check_connected(n, edges):
    """연결 그래프 확인"""
    mst, total_weight = kruskal(n, edges)
 
    # MST 간선 개수가 V-1이면 연결됨
    if len(mst) == n - 1:
        return mst, total_weight
    else:
        return None, -1  # 연결되지 않음

2. 간선 중복

# 같은 간선이 여러 번 있으면
# 정렬 시 가장 작은 가중치가 앞에 옴
edges = [(0, 1, 5), (0, 1, 3), (0, 1, 4)]
edges.sort(key=lambda x: x[2])
# [(0, 1, 3), (0, 1, 4), (0, 1, 5)]

3. Union-Find 최적화

# 경로 압축 + 랭크 기반 합치기
# 시간복잡도: O(α(n)) ≈ O(1)
# α(n): 아커만 함수의 역함수

자주 나오는 문제 유형

1. 기본 MST

문제: 최소 신장 트리의 가중치 합

def minimum_spanning_tree(n, edges):
    """최소 신장 트리"""
    mst, total_weight = kruskal(n, edges)
 
    if len(mst) != n - 1:
        return -1  # 연결되지 않음
 
    return total_weight
 
# 테스트
edges = [
    (0, 1, 4), (0, 2, 3), (1, 2, 1),
    (1, 3, 2), (2, 3, 4), (3, 4, 5)
]
 
print(minimum_spanning_tree(5, edges))  # 11

2. 두 번째로 작은 MST

문제: MST보다 큰 두 번째 MST

def second_minimum_spanning_tree(n, edges):
    """두 번째 MST"""
    # 첫 번째 MST
    edges.sort(key=lambda x: x[2])
    parent = list(range(n))
 
    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]
 
    def union(x, y):
        root_x = find(x)
        root_y = find(y)
        if root_x == root_y:
            return False
        parent[root_y] = root_x
        return True
 
    mst_edges = []
    first_mst = 0
 
    for i, (u, v, weight) in enumerate(edges):
        if union(u, v):
            mst_edges.append(i)
            first_mst += weight
 
    # 각 MST 간선을 제외하고 다시 MST 구하기
    second_mst = float('inf')
 
    for skip_idx in mst_edges:
        parent = list(range(n))
        current_mst = 0
        edge_count = 0
 
        for i, (u, v, weight) in enumerate(edges):
            if i == skip_idx:
                continue
 
            if union(u, v):
                current_mst += weight
                edge_count += 1
 
        if edge_count == n - 1:
            second_mst = min(second_mst, current_mst)
 
    return second_mst if second_mst != float('inf') else -1

3. 최대 신장 트리 (Maximum ST)

문제: 가중치 합이 최대인 신장 트리

def maximum_spanning_tree(n, edges):
    """최대 신장 트리"""
    # 가중치를 음수로 바꿔서 최소 신장 트리
    edges_negated = [(u, v, -weight) for u, v, weight in edges]
 
    mst, total_weight = kruskal(n, edges_negated)
 
    return mst, -total_weight
 
# 또는 내림차순 정렬
def maximum_spanning_tree_v2(n, edges):
    """최대 신장 트리 - 내림차순"""
    edges.sort(key=lambda x: -x[2])  # 큰 것부터
 
    parent = list(range(n))
 
    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]
 
    def union(x, y):
        root_x = find(x)
        root_y = find(y)
        if root_x == root_y:
            return False
        parent[root_y] = root_x
        return True
 
    mst = []
    total_weight = 0
 
    for u, v, weight in edges:
        if union(u, v):
            mst.append((u, v, weight))
            total_weight += weight
 
    return mst, total_weight

4. MST에 포함되는 간선 판별

문제: 특정 간선이 MST에 포함되는지 확인

def is_edge_in_mst(n, edges, target_edge):
    """특정 간선이 MST에 포함되는지"""
    mst, _ = kruskal(n, edges)
 
    target_u, target_v, target_w = target_edge
 
    for u, v, w in mst:
        if (u == target_u and v == target_v and w == target_w) or \
           (u == target_v and v == target_u and w == target_w):
            return True
 
    return False
 
# 또는 직접 확인
def must_be_in_mst(n, edges, target_u, target_v, target_w):
    """반드시 MST에 포함되어야 하는지"""
    # target 간선 제외하고 MST 구하기
    other_edges = [(u, v, w) for u, v, w in edges
                   if not (u == target_u and v == target_v and w == target_w)]
 
    mst_without_target, _ = kruskal(n, other_edges)
 
    # 연결이 안 되거나, 더 비싸지면 필수 간선
    if len(mst_without_target) < n - 1:
        return True
 
    return False

5. 네트워크 연결 비용

문제: 모든 컴퓨터를 연결하는 최소 비용

def minimum_cost_to_connect(n, connections):
    """네트워크 연결"""
    # connections: [(컴퓨터1, 컴퓨터2, 비용), ...]
 
    mst, total_cost = kruskal(n, connections)
 
    if len(mst) != n - 1:
        return -1  # 연결 불가
 
    return total_cost
 
# 테스트 (LeetCode 1584)
connections = [
    (0, 1, 1), (1, 2, 1), (2, 0, 1),
    (1, 3, 1), (3, 4, 1)
]
 
print(minimum_cost_to_connect(5, connections))  # 4

6. 도시 분할 계획

문제: N개의 집을 두 마을로 분할하는 최소 비용

def divide_town(n, edges):
    """마을을 두 개로 분할"""
    # MST를 구한 후 가장 큰 간선 제거
    mst, total_weight = kruskal(n, edges)
 
    # MST 간선 중 가장 큰 가중치
    max_edge_weight = max(weight for u, v, weight in mst)
 
    # 가장 큰 간선을 제거하여 두 개로 분할
    return total_weight - max_edge_weight
 
# 테스트 (백준 1647)
edges = [
    (0, 1, 3), (0, 2, 2), (1, 2, 1),
    (1, 3, 5), (2, 3, 4), (3, 4, 6)
]
 
print(divide_town(5, edges))  # 9 (15 - 6)

추천 연습 문제

기초

중급

고급


언제 사용할까?

사용 가능

  • 최소 신장 트리: MST 문제
  • 희소 그래프: 간선 수 적음
  • 간선 기준 선택: 간선 중심 접근
  • 구현 간단: 빠르게 구현 필요

사용 불가

  • 밀집 그래프: 프림이 더 효율적
  • 최단 경로: 다익스트라 사용

결론: 희소 그래프의 MST 문제에 최적!