개념

세그먼트 트리 - 구간에 대한 질의를 효율적으로 처리하는 트리

배열의 구간 합, 최솟값, 최댓값 등을 빠르게 계산


시간복잡도

연산시간복잡도일반 배열
구간 쿼리O(log n)O(n)
값 업데이트O(log n)O(1)
트리 생성O(n)O(1)

특징

구조

배열: [1, 3, 5, 7, 9, 11]

세그먼트 트리 (구간 합):
           36
         /    \
       9       27
      / \     /  \
     4   5  16   11
    / \     / \
   1  3    7  9

장점

  • 빠른 구간 쿼리: O(log n)
  • 빠른 업데이트: O(log n)
  • 다양한 연산: 합, 최솟값, 최댓값, GCD 등

단점

  • 공간복잡도: O(4n) - 배열의 4배
  • 구현 복잡도: 상대적으로 어려움

Python 구현

1. 구간 합 세그먼트 트리

class SegmentTree:
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        self.build(arr, 0, 0, self.n - 1)
 
    def build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
        else:
            mid = (start + end) // 2
            left_child = 2 * node + 1
            right_child = 2 * node + 2
 
            self.build(arr, left_child, start, mid)
            self.build(arr, right_child, mid + 1, end)
 
            self.tree[node] = self.tree[left_child] + self.tree[right_child]
 
    def query(self, left, right):
        return self._query(0, 0, self.n - 1, left, right)
 
    def _query(self, node, start, end, left, right):
        # 범위를 벗어남
        if right < start or end < left:
            return 0
 
        # 범위 안에 완전히 포함
        if left <= start and end <= right:
            return self.tree[node]
 
        # 일부만 포함
        mid = (start + end) // 2
        left_sum = self._query(2 * node + 1, start, mid, left, right)
        right_sum = self._query(2 * node + 2, mid + 1, end, left, right)
 
        return left_sum + right_sum
 
    def update(self, idx, val):
        self._update(0, 0, self.n - 1, idx, val)
 
    def _update(self, node, start, end, idx, val):
        if start == end:
            self.tree[node] = val
        else:
            mid = (start + end) // 2
            left_child = 2 * node + 1
            right_child = 2 * node + 2
 
            if idx <= mid:
                self._update(left_child, start, mid, idx, val)
            else:
                self._update(right_child, mid + 1, end, idx, val)
 
            self.tree[node] = self.tree[left_child] + self.tree[right_child]
 
# 테스트
arr = [1, 3, 5, 7, 9, 11]
seg_tree = SegmentTree(arr)
 
print(seg_tree.query(1, 3))  # 15 (3 + 5 + 7)
seg_tree.update(1, 10)
print(seg_tree.query(1, 3))  # 22 (10 + 5 + 7)
 
# 시간복잡도: O(log n)

2. 구간 최솟값 세그먼트 트리

class MinSegmentTree:
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [float('inf')] * (4 * self.n)
        self.build(arr, 0, 0, self.n - 1)
 
    def build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
        else:
            mid = (start + end) // 2
            left_child = 2 * node + 1
            right_child = 2 * node + 2
 
            self.build(arr, left_child, start, mid)
            self.build(arr, right_child, mid + 1, end)
 
            self.tree[node] = min(self.tree[left_child], self.tree[right_child])
 
    def query(self, left, right):
        return self._query(0, 0, self.n - 1, left, right)
 
    def _query(self, node, start, end, left, right):
        if right < start or end < left:
            return float('inf')
 
        if left <= start and end <= right:
            return self.tree[node]
 
        mid = (start + end) // 2
        left_min = self._query(2 * node + 1, start, mid, left, right)
        right_min = self._query(2 * node + 2, mid + 1, end, left, right)
 
        return min(left_min, right_min)
 
    def update(self, idx, val):
        self._update(0, 0, self.n - 1, idx, val)
 
    def _update(self, node, start, end, idx, val):
        if start == end:
            self.tree[node] = val
        else:
            mid = (start + end) // 2
            left_child = 2 * node + 1
            right_child = 2 * node + 2
 
            if idx <= mid:
                self._update(left_child, start, mid, idx, val)
            else:
                self._update(right_child, mid + 1, end, idx, val)
 
            self.tree[node] = min(self.tree[left_child], self.tree[right_child])
 
# 테스트
arr = [1, 3, 5, 7, 9, 11]
seg_tree = MinSegmentTree(arr)
 
print(seg_tree.query(1, 4))  # 3 (min of [3, 5, 7, 9])
seg_tree.update(1, 0)
print(seg_tree.query(1, 4))  # 0

3. Lazy Propagation (구간 업데이트)

class LazySegmentTree:
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        self.lazy = [0] * (4 * self.n)
        self.build(arr, 0, 0, self.n - 1)
 
    def build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
        else:
            mid = (start + end) // 2
            self.build(arr, 2 * node + 1, start, mid)
            self.build(arr, 2 * node + 2, mid + 1, end)
            self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]
 
    def push(self, node, start, end):
        if self.lazy[node] != 0:
            self.tree[node] += (end - start + 1) * self.lazy[node]
 
            if start != end:
                self.lazy[2 * node + 1] += self.lazy[node]
                self.lazy[2 * node + 2] += self.lazy[node]
 
            self.lazy[node] = 0
 
    def update_range(self, left, right, val):
        self._update_range(0, 0, self.n - 1, left, right, val)
 
    def _update_range(self, node, start, end, left, right, val):
        self.push(node, start, end)
 
        if right < start or end < left:
            return
 
        if left <= start and end <= right:
            self.lazy[node] += val
            self.push(node, start, end)
            return
 
        mid = (start + end) // 2
        self._update_range(2 * node + 1, start, mid, left, right, val)
        self._update_range(2 * node + 2, mid + 1, end, left, right, val)
 
        self.push(2 * node + 1, start, mid)
        self.push(2 * node + 2, mid + 1, end)
 
        self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]
 
    def query(self, left, right):
        return self._query(0, 0, self.n - 1, left, right)
 
    def _query(self, node, start, end, left, right):
        self.push(node, start, end)
 
        if right < start or end < left:
            return 0
 
        if left <= start and end <= right:
            return self.tree[node]
 
        mid = (start + end) // 2
        left_sum = self._query(2 * node + 1, start, mid, left, right)
        right_sum = self._query(2 * node + 2, mid + 1, end, left, right)
 
        return left_sum + right_sum
 
# 테스트
arr = [1, 3, 5, 7, 9, 11]
seg_tree = LazySegmentTree(arr)
 
print(seg_tree.query(1, 3))  # 15
seg_tree.update_range(1, 3, 10)  # [1, 3] 구간에 10 더하기
print(seg_tree.query(1, 3))  # 45
 
# 시간복잡도: O(log n)

자주 나오는 문제 유형

1. 구간 합 구하기

문제: 배열의 일부 구간 합을 여러 번 구하기

def range_sum_query(arr, queries):
    seg_tree = SegmentTree(arr)
    result = []
 
    for left, right in queries:
        result.append(seg_tree.query(left, right))
 
    return result
 
# 테스트
arr = [1, 3, 5, 7, 9, 11]
queries = [(1, 3), (2, 5), (0, 5)]
print(range_sum_query(arr, queries))
# [15, 30, 36]
 
# 시간복잡도: O(n + q log n) - q: 쿼리 개수

2. 구간 최솟값 쿼리

문제: 배열의 구간 최솟값을 여러 번 구하기

def range_min_query(arr, queries):
    seg_tree = MinSegmentTree(arr)
    result = []
 
    for left, right in queries:
        result.append(seg_tree.query(left, right))
 
    return result
 
# 테스트
arr = [5, 3, 7, 2, 9, 1]
queries = [(0, 2), (1, 4), (2, 5)]
print(range_min_query(arr, queries))
# [3, 2, 1]
 
# 시간복잡도: O(n + q log n)

3. 구간 업데이트와 쿼리

문제: 구간 업데이트와 구간 쿼리를 반복

def range_operations(arr, operations):
    seg_tree = LazySegmentTree(arr)
    result = []
 
    for op in operations:
        if op[0] == 'update':
            _, left, right, val = op
            seg_tree.update_range(left, right, val)
        else:  # query
            _, left, right = op
            result.append(seg_tree.query(left, right))
 
    return result
 
# 테스트
arr = [1, 3, 5, 7, 9]
operations = [
    ('query', 0, 2),
    ('update', 1, 3, 5),
    ('query', 0, 2),
    ('query', 1, 3)
]
print(range_operations(arr, operations))
# [9, 24, 35]
 
# 시간복잡도: O(n + q log n)

4. 2D 세그먼트 트리

문제: 2차원 배열의 구간 합

class SegmentTree2D:
    def __init__(self, matrix):
        self.n = len(matrix)
        self.m = len(matrix[0])
        self.tree = [[0] * (4 * self.m) for _ in range(4 * self.n)]
        self.build_x(matrix, 0, 0, self.n - 1)
 
    def build_y(self, matrix, vx, lx, rx, vy, ly, ry):
        if ly == ry:
            if lx == rx:
                self.tree[vx][vy] = matrix[lx][ly]
            else:
                self.tree[vx][vy] = self.tree[2*vx+1][vy] + self.tree[2*vx+2][vy]
        else:
            my = (ly + ry) // 2
            self.build_y(matrix, vx, lx, rx, 2*vy+1, ly, my)
            self.build_y(matrix, vx, lx, rx, 2*vy+2, my+1, ry)
            self.tree[vx][vy] = self.tree[vx][2*vy+1] + self.tree[vx][2*vy+2]
 
    def build_x(self, matrix, vx, lx, rx):
        if lx != rx:
            mx = (lx + rx) // 2
            self.build_x(matrix, 2*vx+1, lx, mx)
            self.build_x(matrix, 2*vx+2, mx+1, rx)
        self.build_y(matrix, vx, lx, rx, 0, 0, self.m - 1)
 
# 간단한 버전: 1D 세그먼트 트리 사용
def query_2d_simple(matrix, r1, c1, r2, c2):
    # Prefix Sum 사용이 더 효율적
    prefix = [[0] * (len(matrix[0]) + 1) for _ in range(len(matrix) + 1)]
 
    for i in range(len(matrix)):
        for j in range(len(matrix[0])):
            prefix[i+1][j+1] = (prefix[i+1][j] + prefix[i][j+1]
                               - prefix[i][j] + matrix[i][j])
 
    return (prefix[r2+1][c2+1] - prefix[r1][c2+1]
            - prefix[r2+1][c1] + prefix[r1][c1])
 
# 테스트
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(query_2d_simple(matrix, 1, 1, 2, 2))  # 28 (5+6+8+9)

추천 연습 문제

기초

중급

고급


핵심 요약

개념설명
세그먼트 트리구간 쿼리를 O(log n)에 처리
구간 쿼리합, 최솟값, 최댓값, GCD 등
업데이트값 변경도 O(log n)
Lazy Propagation구간 업데이트를 효율적으로

세그먼트 트리 vs 다른 자료구조

자료구조구간 쿼리업데이트공간구현 난이도
배열O(n)O(1)O(n)쉬움
Prefix SumO(1)O(n)O(n)쉬움
세그먼트 트리O(log n)O(log n)O(4n)어려움
Fenwick TreeO(log n)O(log n)O(n)중간

세그먼트 트리 변형

1. 구간 합

self.tree[node] = self.tree[left] + self.tree[right]

2. 구간 최솟값

self.tree[node] = min(self.tree[left], self.tree[right])

3. 구간 최댓값

self.tree[node] = max(self.tree[left], self.tree[right])

4. 구간 GCD

import math
self.tree[node] = math.gcd(self.tree[left], self.tree[right])

주의사항

1. 트리 크기

# 안전하게 4배 할당
self.tree = [0] * (4 * n)

2. 인덱싱

# 0-based 인덱스 사용
# 왼쪽 자식: 2 * node + 1
# 오른쪽 자식: 2 * node + 2

3. 범위 확인

# 쿼리 범위가 노드 범위를 벗어나는지 확인
if right < start or end < left:
    return 0  # 또는 적절한 기본값

4. Lazy Propagation 사용 시기

# 구간 업데이트가 많을 때만 사용
# 단순 업데이트는 일반 세그먼트 트리가 더 간단

세그먼트 트리 템플릿

기본 구조

class SegmentTree:
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        self.build(arr, 0, 0, self.n - 1)
 
    def build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
        else:
            mid = (start + end) // 2
            self.build(arr, 2*node+1, start, mid)
            self.build(arr, 2*node+2, mid+1, end)
            # 합치기 연산 (합, 최솟값, 최댓값 등)
            self.tree[node] = self.tree[2*node+1] + self.tree[2*node+2]
 
    def query(self, left, right):
        return self._query(0, 0, self.n-1, left, right)
 
    def _query(self, node, start, end, left, right):
        if right < start or end < left:
            return 0  # 기본값
        if left <= start and end <= right:
            return self.tree[node]
        mid = (start + end) // 2
        return (self._query(2*node+1, start, mid, left, right) +
                self._query(2*node+2, mid+1, end, left, right))
 
    def update(self, idx, val):
        self._update(0, 0, self.n-1, idx, val)
 
    def _update(self, node, start, end, idx, val):
        if start == end:
            self.tree[node] = val
        else:
            mid = (start + end) // 2
            if idx <= mid:
                self._update(2*node+1, start, mid, idx, val)
            else:
                self._update(2*node+2, mid+1, end, idx, val)
            # 합치기 연산
            self.tree[node] = self.tree[2*node+1] + self.tree[2*node+2]

활용 분야

1. 구간 통계

  • 구간 합, 평균
  • 구간 최솟값, 최댓값
  • 구간 GCD, LCM

2. 온라인 쿼리

  • 실시간 데이터 처리
  • 스트리밍 통계

3. 알고리즘 최적화

  • 동적 프로그래밍 최적화
  • 분할 정복 최적화

4. 게임 개발

  • 맵 구간 쿼리
  • 충돌 감지