개념

분할 정복 (Divide and Conquer) 방식의 정렬

배열을 반으로 나누고, 각각 정렬한 후 병합


시간복잡도

케이스시간복잡도
최선O(n log n)
평균O(n log n)
최악O(n log n)
공간복잡도O(n)

특징

장점

  • 항상 O(n log n): 최악의 경우에도 보장
  • 안정 정렬: 같은 값의 순서 유지
  • 예측 가능한 성능: 데이터 분포에 영향 안 받음
  • 병렬 처리 가능: 독립적인 부분 문제

단점

  • 추가 메모리 필요: O(n) 공간복잡도
  • 작은 데이터에는 비효율: 오버헤드 존재
  • 제자리 정렬 아님: 추가 배열 필요

구현 방법

기본 구현 (Top-Down)

def merge_sort(arr):
    # 기저 조건: 크기가 1 이하면 이미 정렬됨
    if len(arr) <= 1:
        return arr
 
    # 분할 (Divide)
    mid = len(arr) // 2
    left = merge_sort(arr[:mid])
    right = merge_sort(arr[mid:])
 
    # 병합 (Conquer)
    return merge(left, right)
 
def merge(left, right):
    """두 정렬된 배열을 병합"""
    result = []
    i = j = 0
 
    # 두 배열을 비교하며 병합
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
 
    # 남은 요소 추가
    result.extend(left[i:])
    result.extend(right[j:])
 
    return result
 
# 테스트
print(merge_sort([38, 27, 43, 3, 9, 82, 10]))
# [3, 9, 10, 27, 38, 43, 82]

제자리 정렬 구현 (In-Place)

def merge_sort_inplace(arr, left, right):
    """제자리 병합 정렬 (공간복잡도 O(n)이지만 입력 배열 수정)"""
    if left >= right:
        return
 
    mid = (left + right) // 2
 
    # 분할
    merge_sort_inplace(arr, left, mid)
    merge_sort_inplace(arr, mid + 1, right)
 
    # 병합
    merge_inplace(arr, left, mid, right)
 
def merge_inplace(arr, left, mid, right):
    """두 부분을 병합"""
    # 임시 배열 생성
    left_part = arr[left:mid + 1]
    right_part = arr[mid + 1:right + 1]
 
    i = j = 0
    k = left
 
    # 병합
    while i < len(left_part) and j < len(right_part):
        if left_part[i] <= right_part[j]:
            arr[k] = left_part[i]
            i += 1
        else:
            arr[k] = right_part[j]
            j += 1
        k += 1
 
    # 남은 요소 복사
    while i < len(left_part):
        arr[k] = left_part[i]
        i += 1
        k += 1
 
    while j < len(right_part):
        arr[k] = right_part[j]
        j += 1
        k += 1
 
# 테스트
arr = [38, 27, 43, 3, 9, 82, 10]
merge_sort_inplace(arr, 0, len(arr) - 1)
print(arr)  # [3, 9, 10, 27, 38, 43, 82]

Bottom-Up 구현 (반복문)

def merge_sort_bottom_up(arr):
    """반복문을 이용한 병합 정렬 (재귀 없음)"""
    n = len(arr)
 
    # 크기 1, 2, 4, 8, ... 순으로 병합
    size = 1
    while size < n:
        for start in range(0, n, size * 2):
            mid = min(start + size - 1, n - 1)
            end = min(start + size * 2 - 1, n - 1)
 
            # 병합
            if mid < end:
                merge_inplace(arr, start, mid, end)
 
        size *= 2
 
    return arr
 
# 테스트
print(merge_sort_bottom_up([38, 27, 43, 3, 9, 82, 10]))
# [3, 9, 10, 27, 38, 43, 82]
 
# 재귀 호출 없이 스택 오버플로우 방지

동작 과정

초기: [38, 27, 43, 3, 9, 82, 10]

분할 단계:
[38, 27, 43, 3, 9, 82, 10]
    ↓
[38, 27, 43, 3] | [9, 82, 10]
    ↓                ↓
[38, 27] | [43, 3]  [9, 82] | [10]
   ↓         ↓         ↓        ↓
[38] [27] [43] [3]  [9] [82]  [10]

병합 단계:
[27, 38] | [3, 43]  [9, 82] | [10]
    ↓                  ↓
[3, 27, 38, 43] | [9, 10, 82]
    ↓
[3, 9, 10, 27, 38, 43, 82]

템플릿

def merge_sort_template(arr):
    # 기저 조건
    if len(arr) <= 1:
        return arr
 
    # 1. 분할 (Divide)
    mid = len(arr) // 2
    left = merge_sort_template(arr[:mid])
    right = merge_sort_template(arr[mid:])
 
    # 2. 정복 (Conquer) - 병합
    return merge_template(left, right)
 
def merge_template(left, right):
    result = []
    i = j = 0
 
    # 두 배열 비교하며 병합
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
 
    # 남은 요소 추가
    result.extend(left[i:])
    result.extend(right[j:])
 
    return result

핵심 요약

개념설명
원리분할 정복 (Divide and Conquer)
시간복잡도O(n log n) - 항상 보장
공간복잡도O(n)
안정성안정 정렬
특징최악의 경우에도 O(n log n)

다른 정렬과 비교

정렬시간복잡도 (평균)최악공간복잡도안정성
병합 정렬O(n log n)O(n log n)O(n)안정
퀵 정렬O(n log n)O(n²)O(log n)불안정
힙 정렬O(n log n)O(n log n)O(1)불안정
삽입 정렬O(n²)O(n²)O(1)안정

주의사항

1. 메모리 사용

# 추가 메모리 O(n) 필요
# 메모리가 제한적이면 힙 정렬 고려
 
# 슬라이싱은 새 배열 생성
left = arr[:mid]  # O(n) 공간

2. 작은 배열 최적화

def merge_sort_optimized(arr):
    # 작은 배열은 삽입 정렬 사용
    if len(arr) <= 10:
        return insertion_sort(arr)
 
    # 큰 배열은 병합 정렬
    mid = len(arr) // 2
    left = merge_sort_optimized(arr[:mid])
    right = merge_sort_optimized(arr[mid:])
    return merge(left, right)

3. 재귀 깊이

# 재귀 깊이: O(log n)
# Python 기본 재귀 한도: 1000
# 큰 배열은 재귀 한도 증가 필요
 
import sys
sys.setrecursionlimit(10**6)

자주 나오는 문제 유형

1. 역순 쌍 개수 세기 (Inversion Count)

문제: 배열에서 i < j이지만 arr[i] > arr[j]인 쌍의 개수

def count_inversions(arr):
    """병합 정렬을 이용한 역순 쌍 개수 계산"""
    if len(arr) <= 1:
        return arr, 0
 
    mid = len(arr) // 2
    left, left_inv = count_inversions(arr[:mid])
    right, right_inv = count_inversions(arr[mid:])
 
    merged, split_inv = merge_and_count(left, right)
 
    total_inv = left_inv + right_inv + split_inv
    return merged, total_inv
 
def merge_and_count(left, right):
    """병합하면서 역순 쌍 개수 세기"""
    result = []
    inversions = 0
    i = j = 0
 
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            result.append(left[i])
            i += 1
        else:
            # right[j]가 선택되면, left의 남은 모든 요소와 역순 쌍
            result.append(right[j])
            inversions += len(left) - i
            j += 1
 
    result.extend(left[i:])
    result.extend(right[j:])
 
    return result, inversions
 
# 테스트
arr = [8, 4, 2, 1]
sorted_arr, count = count_inversions(arr)
print(f"정렬 결과: {sorted_arr}")
print(f"역순 쌍 개수: {count}")  # 6
 
# 시간복잡도: O(n log n)

핵심: 병합 과정에서 교차하는 쌍의 개수 계산


2. K개의 정렬된 배열 병합

문제: K개의 정렬된 배열을 하나의 정렬된 배열로 병합

import heapq
 
def merge_k_sorted_arrays(arrays):
    """K개의 정렬된 배열을 병합"""
    # 최소 힙 생성 (값, 배열 인덱스, 요소 인덱스)
    heap = []
    result = []
 
    # 각 배열의 첫 요소를 힙에 추가
    for i, arr in enumerate(arrays):
        if arr:
            heapq.heappush(heap, (arr[0], i, 0))
 
    # 힙에서 하나씩 꺼내며 결과에 추가
    while heap:
        val, arr_idx, elem_idx = heapq.heappop(heap)
        result.append(val)
 
        # 다음 요소가 있으면 힙에 추가
        if elem_idx + 1 < len(arrays[arr_idx]):
            next_val = arrays[arr_idx][elem_idx + 1]
            heapq.heappush(heap, (next_val, arr_idx, elem_idx + 1))
 
    return result
 
# 테스트
arrays = [
    [1, 4, 7],
    [2, 5, 8],
    [3, 6, 9]
]
print(merge_k_sorted_arrays(arrays))
# [1, 2, 3, 4, 5, 6, 7, 8, 9]
 
# 시간복잡도: O(N log K) - N: 전체 요소 수, K: 배열 개수

핵심: 최소 힙을 사용하여 효율적으로 병합


3. 연결 리스트 병합 정렬

문제: 연결 리스트를 병합 정렬로 정렬

class Node:
    def __init__(self, data):
        self.data = data
        self.next = None
 
def merge_sort_list(head):
    """연결 리스트 병합 정렬"""
    if not head or not head.next:
        return head
 
    # 중간 지점 찾기 (Fast & Slow Pointer)
    slow, fast = head, head.next
    while fast and fast.next:
        slow = slow.next
        fast = fast.next.next
 
    # 리스트를 두 부분으로 분할
    mid = slow.next
    slow.next = None
 
    # 재귀적으로 정렬
    left = merge_sort_list(head)
    right = merge_sort_list(mid)
 
    # 병합
    return merge_lists(left, right)
 
def merge_lists(l1, l2):
    """두 정렬된 연결 리스트 병합"""
    dummy = Node(0)
    current = dummy
 
    while l1 and l2:
        if l1.data <= l2.data:
            current.next = l1
            l1 = l1.next
        else:
            current.next = l2
            l2 = l2.next
        current = current.next
 
    current.next = l1 if l1 else l2
    return dummy.next
 
# 테스트용 함수
def create_list(arr):
    dummy = Node(0)
    current = dummy
    for val in arr:
        current.next = Node(val)
        current = current.next
    return dummy.next
 
def print_list(head):
    values = []
    while head:
        values.append(head.data)
        head = head.next
    print(values)
 
# 테스트
head = create_list([4, 2, 1, 3])
sorted_head = merge_sort_list(head)
print_list(sorted_head)  # [1, 2, 3, 4]
 
# 시간복잡도: O(n log n)
# 공간복잡도: O(log n) - 재귀 스택

핵심: 연결 리스트는 추가 배열 없이 O(log n) 공간으로 가능


4. 외부 정렬 (External Sort)

문제: 메모리보다 큰 파일을 정렬

def external_merge_sort(input_file, output_file, chunk_size):
    """
    외부 정렬 (메모리보다 큰 데이터)
    1. 파일을 chunk_size만큼 읽어 정렬 후 임시 파일 생성
    2. 임시 파일들을 병합
    """
    import tempfile
    import heapq
 
    # Phase 1: 청크로 나눠 정렬
    temp_files = []
    with open(input_file, 'r') as f:
        while True:
            chunk = []
            for _ in range(chunk_size):
                line = f.readline()
                if not line:
                    break
                chunk.append(int(line.strip()))
 
            if not chunk:
                break
 
            # 청크 정렬 후 임시 파일에 저장
            chunk.sort()
            temp_file = tempfile.NamedTemporaryFile(mode='w+', delete=False)
            for num in chunk:
                temp_file.write(f"{num}\n")
            temp_file.seek(0)
            temp_files.append(temp_file)
 
    # Phase 2: K-way 병합
    with open(output_file, 'w') as output:
        # 각 파일의 첫 요소를 힙에 추가
        heap = []
        for i, f in enumerate(temp_files):
            line = f.readline()
            if line:
                heapq.heappush(heap, (int(line.strip()), i))
 
        # 병합
        while heap:
            val, file_idx = heapq.heappop(heap)
            output.write(f"{val}\n")
 
            # 다음 요소 읽기
            line = temp_files[file_idx].readline()
            if line:
                heapq.heappush(heap, (int(line.strip()), file_idx))
 
    # 임시 파일 삭제
    for f in temp_files:
        f.close()
 
# 사용 예시
# external_merge_sort('large_input.txt', 'sorted_output.txt', chunk_size=1000)
 
# 시간복잡도: O(n log n)
# 공간복잡도: O(chunk_size)

핵심: 대용량 데이터를 청크로 나눠 처리


추천 연습 문제

기초

중급


병합 정렬의 시간복잡도 증명

T(n) = 2T(n/2) + O(n)

높이가 log n인 완전 이진 트리:
- 각 레벨에서 O(n) 작업
- 총 레벨 수: log n

∴ T(n) = O(n log n)

분할: O(1)
병합: O(n)
재귀 깊이: O(log n)
→ 총 시간: O(n log n)

언제 사용할까?

사용 가능

  • 안정 정렬 필요: 같은 값의 순서 유지
  • 최악의 경우 보장: O(n log n) 보장
  • 연결 리스트 정렬: 추가 메모리 O(log n)만 필요
  • 외부 정렬: 대용량 파일 정렬
  • 병렬 처리: 독립적인 부분 문제

사용 불가

  • 메모리 제한: O(n) 추가 메모리 불가능
  • 작은 데이터: 삽입 정렬이 더 빠름
  • 캐시 효율 중요: 퀵 정렬이 더 나음

결론: 안정 정렬이 필요하거나 최악의 경우 보장이 중요할 때 최적!