개념

**최소 신장 트리 (MST)**를 구하는 그리디 알고리즘

하나의 정점에서 시작하여 트리를 확장해나가며 가장 작은 간선을 선택


시간복잡도

구현 방식시간복잡도
인접 행렬 + 배열O(V²)
인접 리스트 + 우선순위 큐O(E log V)
공간복잡도O(V + E)

특징

장점

  • 밀집 그래프 효율적: 간선 많으면 빠름
  • 점진적 구성: 시작 정점부터 트리 확장
  • 다익스트라와 유사: 우선순위 큐 사용

단점

  • 시작 정점 필요: 임의의 정점 선택
  • 구현 복잡: 크루스칼보다 복잡

MST란?

  • 최소 신장 트리: 모든 정점을 최소 비용으로 연결
  • 간선 개수: V-1개
  • 사이클 없음: 트리 구조

구현 방법

우선순위 큐 구현 (추천)

import heapq
 
def prim(n, graph, start=0):
    """프림 알고리즘 - 우선순위 큐"""
    visited = [False] * n
    mst = []
    total_weight = 0
 
    # 우선순위 큐: (가중치, 정점, 이전 정점)
    pq = [(0, start, -1)]
 
    while pq:
        weight, current, prev = heapq.heappop(pq)
 
        # 이미 방문한 정점
        if visited[current]:
            continue
 
        visited[current] = True
 
        # MST에 간선 추가 (시작 정점 제외)
        if prev != -1:
            mst.append((prev, current, weight))
            total_weight += weight
 
        # 인접 정점 탐색
        for neighbor, edge_weight in graph[current]:
            if not visited[neighbor]:
                heapq.heappush(pq, (edge_weight, neighbor, current))
 
    return mst, total_weight
 
# 테스트
# graph[정점] = [(인접 정점, 가중치), ...]
graph = [
    [(1, 4), (2, 3)],           # 0번 정점
    [(0, 4), (2, 1), (3, 2)],   # 1번 정점
    [(0, 3), (1, 1), (3, 4)],   # 2번 정점
    [(1, 2), (2, 4)]            # 3번 정점
]
 
mst, total_weight = prim(4, graph)
print(f"MST 간선: {mst}")
print(f"총 가중치: {total_weight}")  # 6
 
# 시간복잡도: O(E log V)
# 공간복잡도: O(V + E)

배열 기반 구현 (밀집 그래프)

def prim_array(n, graph, start=0):
    """배열 기반 프림 - O(V²)"""
    visited = [False] * n
    min_edge = [float('inf')] * n
    min_edge[start] = 0
    parent = [-1] * n
 
    mst = []
    total_weight = 0
 
    for _ in range(n):
        # 미방문 정점 중 최소 가중치 찾기
        min_weight = float('inf')
        min_vertex = -1
 
        for v in range(n):
            if not visited[v] and min_edge[v] < min_weight:
                min_weight = min_edge[v]
                min_vertex = v
 
        if min_vertex == -1:
            break
 
        visited[min_vertex] = True
 
        # MST에 간선 추가
        if parent[min_vertex] != -1:
            mst.append((parent[min_vertex], min_vertex, min_weight))
            total_weight += min_weight
 
        # 인접 정점 갱신
        for neighbor, weight in graph[min_vertex]:
            if not visited[neighbor] and weight < min_edge[neighbor]:
                min_edge[neighbor] = weight
                parent[neighbor] = min_vertex
 
    return mst, total_weight
 
# 테스트
graph = [
    [(1, 4), (2, 3)],
    [(0, 4), (2, 1), (3, 2)],
    [(0, 3), (1, 1), (3, 4)],
    [(1, 2), (2, 4)]
]
 
mst, total_weight = prim_array(4, graph)
print(f"총 가중치: {total_weight}")  # 6
 
# 시간복잡도: O(V²)
# 밀집 그래프에 효율적

인접 행렬 버전

def prim_matrix(matrix, start=0):
    """인접 행렬 기반 프림"""
    n = len(matrix)
    visited = [False] * n
    min_edge = [float('inf')] * n
    min_edge[start] = 0
    parent = [-1] * n
 
    mst = []
    total_weight = 0
 
    for _ in range(n):
        min_weight = float('inf')
        min_vertex = -1
 
        for v in range(n):
            if not visited[v] and min_edge[v] < min_weight:
                min_weight = min_edge[v]
                min_vertex = v
 
        if min_vertex == -1:
            break
 
        visited[min_vertex] = True
 
        if parent[min_vertex] != -1:
            mst.append((parent[min_vertex], min_vertex, min_weight))
            total_weight += min_weight
 
        # 인접 행렬에서 갱신
        for neighbor in range(n):
            weight = matrix[min_vertex][neighbor]
            if weight > 0 and not visited[neighbor] and weight < min_edge[neighbor]:
                min_edge[neighbor] = weight
                parent[neighbor] = min_vertex
 
    return mst, total_weight
 
# 테스트 (인접 행렬, 0은 간선 없음)
matrix = [
    [0, 4, 3, 0],
    [4, 0, 1, 2],
    [3, 1, 0, 4],
    [0, 2, 4, 0]
]
 
mst, total_weight = prim_matrix(matrix)
print(f"총 가중치: {total_weight}")  # 6

동작 과정

그래프:
    0 ---4--- 1
    |         |
    3         2
    |         |
    2 ---1--- 3

시작: 0번 정점

1단계: 0 방문
  인접: 1(가중치 4), 2(가중치 3)
  최소: 2(가중치 3) 선택
  MST: [(0, 2, 3)]
  총 가중치: 3

2단계: 2 방문
  인접: 1(가중치 1), 3(가중치 4)
  최소: 1(가중치 1) 선택
  MST: [(0, 2, 3), (2, 1, 1)]
  총 가중치: 4

3단계: 1 방문
  인접: 3(가중치 2)
  최소: 3(가중치 2) 선택
  MST: [(0, 2, 3), (2, 1, 1), (1, 3, 2)]
  총 가중치: 6

완료: 모든 정점 방문
총 가중치: 6

템플릿

import heapq
 
def prim_template(n, graph, start=0):
    visited = [False] * n
    mst = []
    total_weight = 0
 
    # (가중치, 정점, 이전 정점)
    pq = [(0, start, -1)]
 
    while pq:
        weight, current, prev = heapq.heappop(pq)
 
        if visited[current]:
            continue
 
        visited[current] = True
 
        if prev != -1:
            mst.append((prev, current, weight))
            total_weight += weight
 
        for neighbor, edge_weight in graph[current]:
            if not visited[neighbor]:
                heapq.heappush(pq, (edge_weight, neighbor, current))
 
    return mst, total_weight

핵심 요약

개념설명
원리그리디 - 트리를 점진적으로 확장
시간복잡도O(E log V) - 우선순위 큐
공간복잡도O(V + E)
활용최소 신장 트리 (MST)
핵심 자료구조우선순위 큐

크루스칼과 비교

특성크루스칼프림
시간복잡도O(E log E)O(E log V)
접근간선 중심정점 중심
자료구조Union-Find우선순위 큐
적합한 그래프희소 그래프밀집 그래프
시작점불필요필요

주의사항

1. 연결 그래프 확인

def prim_check_connected(n, graph, start=0):
    """연결 그래프 확인"""
    mst, total_weight = prim(n, graph, start)
 
    if len(mst) == n - 1:
        return mst, total_weight
    else:
        return None, -1  # 연결되지 않음

2. 시작 정점 선택

# 아무 정점이나 선택 가능
# 결과는 동일 (MST 가중치 합)
mst1, w1 = prim(n, graph, start=0)
mst2, w2 = prim(n, graph, start=1)
# w1 == w2 (같은 가중치 합)

3. 우선순위 큐 중복

# 같은 정점이 여러 번 큐에 들어갈 수 있음
# visited 체크로 무시
if visited[current]:
    continue

자주 나오는 문제 유형

1. 기본 MST

문제: 최소 신장 트리의 가중치 합

def minimum_spanning_tree_prim(n, edges):
    """최소 신장 트리"""
    # 인접 리스트 생성
    graph = [[] for _ in range(n)]
    for u, v, weight in edges:
        graph[u].append((v, weight))
        graph[v].append((u, weight))
 
    mst, total_weight = prim(n, graph)
 
    if len(mst) != n - 1:
        return -1  # 연결되지 않음
 
    return total_weight
 
# 테스트
edges = [
    (0, 1, 4), (0, 2, 3), (1, 2, 1),
    (1, 3, 2), (2, 3, 4)
]
 
print(minimum_spanning_tree_prim(4, edges))  # 6

2. 좌표 평면에서 MST

문제: 모든 점을 연결하는 최소 비용

import heapq
import math
 
def distance(p1, p2):
    """유클리드 거리"""
    return math.sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)
 
def mst_points(points):
    """좌표 점들의 MST"""
    n = len(points)
 
    # 인접 리스트 생성 (완전 그래프)
    graph = [[] for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            dist = distance(points[i], points[j])
            graph[i].append((j, dist))
            graph[j].append((i, dist))
 
    mst, total_cost = prim(n, graph)
 
    return round(total_cost, 2)
 
# 테스트 (백준 4386)
points = [(1.0, 1.0), (2.0, 2.0), (2.0, 4.0)]
print(mst_points(points))  # 3.41

3. 특정 간선 포함 MST

문제: 특정 간선을 반드시 포함하는 MST

def mst_with_edge(n, graph, must_edge_u, must_edge_v, must_edge_w):
    """특정 간선을 포함하는 MST"""
    # 1. 두 정점을 미리 연결
    visited = [False] * n
    visited[must_edge_u] = True
    visited[must_edge_v] = True
 
    mst = [(must_edge_u, must_edge_v, must_edge_w)]
    total_weight = must_edge_w
 
    # 2. 두 정점에서 시작하여 확장
    pq = []
 
    for neighbor, weight in graph[must_edge_u]:
        if not visited[neighbor]:
            heapq.heappush(pq, (weight, neighbor, must_edge_u))
 
    for neighbor, weight in graph[must_edge_v]:
        if not visited[neighbor]:
            heapq.heappush(pq, (weight, neighbor, must_edge_v))
 
    # 3. 나머지 프림 알고리즘
    while pq:
        weight, current, prev = heapq.heappop(pq)
 
        if visited[current]:
            continue
 
        visited[current] = True
        mst.append((prev, current, weight))
        total_weight += weight
 
        for neighbor, edge_weight in graph[current]:
            if not visited[neighbor]:
                heapq.heappush(pq, (edge_weight, neighbor, current))
 
    return mst, total_weight

4. 네트워크 연결 (최소 비용)

문제: 모든 컴퓨터를 연결하는 최소 비용

def min_cost_connect_all(n, connections):
    """네트워크 연결"""
    graph = [[] for _ in range(n)]
 
    for u, v, cost in connections:
        graph[u].append((v, cost))
        graph[v].append((u, cost))
 
    mst, total_cost = prim(n, graph)
 
    if len(mst) != n - 1:
        return -1  # 연결 불가
 
    return total_cost
 
# 테스트
connections = [
    (0, 1, 1), (1, 2, 1), (2, 0, 1),
    (1, 3, 1), (3, 4, 1)
]
 
print(min_cost_connect_all(5, connections))  # 4

5. 최소 비용으로 모든 도시 연결

문제: 이미 연결된 도시가 있을 때 추가 비용

def mst_with_existing_edges(n, graph, existing_edges):
    """이미 연결된 간선이 있는 MST"""
    visited = [False] * n
    mst = []
    total_weight = 0
 
    # 이미 연결된 간선들로 컴포넌트 생성
    components = list(range(n))
 
    def find(x):
        if components[x] != x:
            components[x] = find(components[x])
        return components[x]
 
    def union(x, y):
        root_x = find(x)
        root_y = find(y)
        if root_x != root_y:
            components[root_y] = root_x
            return True
        return False
 
    # 기존 간선 처리
    for u, v in existing_edges:
        union(u, v)
 
    # 각 컴포넌트의 대표를 방문 처리
    visited_components = set()
    for i in range(n):
        visited_components.add(find(i))
 
    # 프림으로 나머지 연결
    # (구현 생략 - 복잡함)
 
    return mst, total_weight

추천 연습 문제

기초

중급

고급


프림 vs 크루스칼 선택 기준

프림 사용

# 밀집 그래프 (간선 많음)
if E > V * V / 2:
    use_prim()
 
# 우선순위 큐 익숙
# 다익스트라와 유사한 구조

크루스칼 사용

# 희소 그래프 (간선 적음)
if E < V * V / 2:
    use_kruskal()
 
# 간선 정렬이 쉬움
# Union-Find 익숙

언제 사용할까?

사용 가능

  • 최소 신장 트리: MST 문제
  • 밀집 그래프: 간선 수 많음
  • 정점 기준 선택: 정점 중심 접근
  • 다익스트라 익숙: 유사한 구조

사용 불가

  • 희소 그래프: 크루스칼이 더 효율적
  • 간선 기준: 크루스칼 사용

결론: 밀집 그래프의 MST 문제에 최적!