개념
세그먼트 트리 - 구간에 대한 질의를 효율적으로 처리하는 트리
배열의 구간 합, 최솟값, 최댓값 등을 빠르게 계산
시간복잡도
연산 | 시간복잡도 | 일반 배열 |
---|---|---|
구간 쿼리 | O(log n) | O(n) |
값 업데이트 | O(log n) | O(1) |
트리 생성 | O(n) | O(1) |
특징
구조
배열: [1, 3, 5, 7, 9, 11]
세그먼트 트리 (구간 합):
36
/ \
9 27
/ \ / \
4 5 16 11
/ \ / \
1 3 7 9
장점
- 빠른 구간 쿼리: O(log n)
- 빠른 업데이트: O(log n)
- 다양한 연산: 합, 최솟값, 최댓값, GCD 등
단점
- 공간복잡도: O(4n) - 배열의 4배
- 구현 복잡도: 상대적으로 어려움
Python 구현
1. 구간 합 세그먼트 트리
class SegmentTree:
def __init__(self, arr):
self.n = len(arr)
self.tree = [0] * (4 * self.n)
self.build(arr, 0, 0, self.n - 1)
def build(self, arr, node, start, end):
if start == end:
self.tree[node] = arr[start]
else:
mid = (start + end) // 2
left_child = 2 * node + 1
right_child = 2 * node + 2
self.build(arr, left_child, start, mid)
self.build(arr, right_child, mid + 1, end)
self.tree[node] = self.tree[left_child] + self.tree[right_child]
def query(self, left, right):
return self._query(0, 0, self.n - 1, left, right)
def _query(self, node, start, end, left, right):
# 범위를 벗어남
if right < start or end < left:
return 0
# 범위 안에 완전히 포함
if left <= start and end <= right:
return self.tree[node]
# 일부만 포함
mid = (start + end) // 2
left_sum = self._query(2 * node + 1, start, mid, left, right)
right_sum = self._query(2 * node + 2, mid + 1, end, left, right)
return left_sum + right_sum
def update(self, idx, val):
self._update(0, 0, self.n - 1, idx, val)
def _update(self, node, start, end, idx, val):
if start == end:
self.tree[node] = val
else:
mid = (start + end) // 2
left_child = 2 * node + 1
right_child = 2 * node + 2
if idx <= mid:
self._update(left_child, start, mid, idx, val)
else:
self._update(right_child, mid + 1, end, idx, val)
self.tree[node] = self.tree[left_child] + self.tree[right_child]
# 테스트
arr = [1, 3, 5, 7, 9, 11]
seg_tree = SegmentTree(arr)
print(seg_tree.query(1, 3)) # 15 (3 + 5 + 7)
seg_tree.update(1, 10)
print(seg_tree.query(1, 3)) # 22 (10 + 5 + 7)
# 시간복잡도: O(log n)
2. 구간 최솟값 세그먼트 트리
class MinSegmentTree:
def __init__(self, arr):
self.n = len(arr)
self.tree = [float('inf')] * (4 * self.n)
self.build(arr, 0, 0, self.n - 1)
def build(self, arr, node, start, end):
if start == end:
self.tree[node] = arr[start]
else:
mid = (start + end) // 2
left_child = 2 * node + 1
right_child = 2 * node + 2
self.build(arr, left_child, start, mid)
self.build(arr, right_child, mid + 1, end)
self.tree[node] = min(self.tree[left_child], self.tree[right_child])
def query(self, left, right):
return self._query(0, 0, self.n - 1, left, right)
def _query(self, node, start, end, left, right):
if right < start or end < left:
return float('inf')
if left <= start and end <= right:
return self.tree[node]
mid = (start + end) // 2
left_min = self._query(2 * node + 1, start, mid, left, right)
right_min = self._query(2 * node + 2, mid + 1, end, left, right)
return min(left_min, right_min)
def update(self, idx, val):
self._update(0, 0, self.n - 1, idx, val)
def _update(self, node, start, end, idx, val):
if start == end:
self.tree[node] = val
else:
mid = (start + end) // 2
left_child = 2 * node + 1
right_child = 2 * node + 2
if idx <= mid:
self._update(left_child, start, mid, idx, val)
else:
self._update(right_child, mid + 1, end, idx, val)
self.tree[node] = min(self.tree[left_child], self.tree[right_child])
# 테스트
arr = [1, 3, 5, 7, 9, 11]
seg_tree = MinSegmentTree(arr)
print(seg_tree.query(1, 4)) # 3 (min of [3, 5, 7, 9])
seg_tree.update(1, 0)
print(seg_tree.query(1, 4)) # 0
3. Lazy Propagation (구간 업데이트)
class LazySegmentTree:
def __init__(self, arr):
self.n = len(arr)
self.tree = [0] * (4 * self.n)
self.lazy = [0] * (4 * self.n)
self.build(arr, 0, 0, self.n - 1)
def build(self, arr, node, start, end):
if start == end:
self.tree[node] = arr[start]
else:
mid = (start + end) // 2
self.build(arr, 2 * node + 1, start, mid)
self.build(arr, 2 * node + 2, mid + 1, end)
self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]
def push(self, node, start, end):
if self.lazy[node] != 0:
self.tree[node] += (end - start + 1) * self.lazy[node]
if start != end:
self.lazy[2 * node + 1] += self.lazy[node]
self.lazy[2 * node + 2] += self.lazy[node]
self.lazy[node] = 0
def update_range(self, left, right, val):
self._update_range(0, 0, self.n - 1, left, right, val)
def _update_range(self, node, start, end, left, right, val):
self.push(node, start, end)
if right < start or end < left:
return
if left <= start and end <= right:
self.lazy[node] += val
self.push(node, start, end)
return
mid = (start + end) // 2
self._update_range(2 * node + 1, start, mid, left, right, val)
self._update_range(2 * node + 2, mid + 1, end, left, right, val)
self.push(2 * node + 1, start, mid)
self.push(2 * node + 2, mid + 1, end)
self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]
def query(self, left, right):
return self._query(0, 0, self.n - 1, left, right)
def _query(self, node, start, end, left, right):
self.push(node, start, end)
if right < start or end < left:
return 0
if left <= start and end <= right:
return self.tree[node]
mid = (start + end) // 2
left_sum = self._query(2 * node + 1, start, mid, left, right)
right_sum = self._query(2 * node + 2, mid + 1, end, left, right)
return left_sum + right_sum
# 테스트
arr = [1, 3, 5, 7, 9, 11]
seg_tree = LazySegmentTree(arr)
print(seg_tree.query(1, 3)) # 15
seg_tree.update_range(1, 3, 10) # [1, 3] 구간에 10 더하기
print(seg_tree.query(1, 3)) # 45
# 시간복잡도: O(log n)
자주 나오는 문제 유형
1. 구간 합 구하기
문제: 배열의 일부 구간 합을 여러 번 구하기
def range_sum_query(arr, queries):
seg_tree = SegmentTree(arr)
result = []
for left, right in queries:
result.append(seg_tree.query(left, right))
return result
# 테스트
arr = [1, 3, 5, 7, 9, 11]
queries = [(1, 3), (2, 5), (0, 5)]
print(range_sum_query(arr, queries))
# [15, 30, 36]
# 시간복잡도: O(n + q log n) - q: 쿼리 개수
2. 구간 최솟값 쿼리
문제: 배열의 구간 최솟값을 여러 번 구하기
def range_min_query(arr, queries):
seg_tree = MinSegmentTree(arr)
result = []
for left, right in queries:
result.append(seg_tree.query(left, right))
return result
# 테스트
arr = [5, 3, 7, 2, 9, 1]
queries = [(0, 2), (1, 4), (2, 5)]
print(range_min_query(arr, queries))
# [3, 2, 1]
# 시간복잡도: O(n + q log n)
3. 구간 업데이트와 쿼리
문제: 구간 업데이트와 구간 쿼리를 반복
def range_operations(arr, operations):
seg_tree = LazySegmentTree(arr)
result = []
for op in operations:
if op[0] == 'update':
_, left, right, val = op
seg_tree.update_range(left, right, val)
else: # query
_, left, right = op
result.append(seg_tree.query(left, right))
return result
# 테스트
arr = [1, 3, 5, 7, 9]
operations = [
('query', 0, 2),
('update', 1, 3, 5),
('query', 0, 2),
('query', 1, 3)
]
print(range_operations(arr, operations))
# [9, 24, 35]
# 시간복잡도: O(n + q log n)
4. 2D 세그먼트 트리
문제: 2차원 배열의 구간 합
class SegmentTree2D:
def __init__(self, matrix):
self.n = len(matrix)
self.m = len(matrix[0])
self.tree = [[0] * (4 * self.m) for _ in range(4 * self.n)]
self.build_x(matrix, 0, 0, self.n - 1)
def build_y(self, matrix, vx, lx, rx, vy, ly, ry):
if ly == ry:
if lx == rx:
self.tree[vx][vy] = matrix[lx][ly]
else:
self.tree[vx][vy] = self.tree[2*vx+1][vy] + self.tree[2*vx+2][vy]
else:
my = (ly + ry) // 2
self.build_y(matrix, vx, lx, rx, 2*vy+1, ly, my)
self.build_y(matrix, vx, lx, rx, 2*vy+2, my+1, ry)
self.tree[vx][vy] = self.tree[vx][2*vy+1] + self.tree[vx][2*vy+2]
def build_x(self, matrix, vx, lx, rx):
if lx != rx:
mx = (lx + rx) // 2
self.build_x(matrix, 2*vx+1, lx, mx)
self.build_x(matrix, 2*vx+2, mx+1, rx)
self.build_y(matrix, vx, lx, rx, 0, 0, self.m - 1)
# 간단한 버전: 1D 세그먼트 트리 사용
def query_2d_simple(matrix, r1, c1, r2, c2):
# Prefix Sum 사용이 더 효율적
prefix = [[0] * (len(matrix[0]) + 1) for _ in range(len(matrix) + 1)]
for i in range(len(matrix)):
for j in range(len(matrix[0])):
prefix[i+1][j+1] = (prefix[i+1][j] + prefix[i][j+1]
- prefix[i][j] + matrix[i][j])
return (prefix[r2+1][c2+1] - prefix[r1][c2+1]
- prefix[r2+1][c1] + prefix[r1][c1])
# 테스트
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(query_2d_simple(matrix, 1, 1, 2, 2)) # 28 (5+6+8+9)
추천 연습 문제
기초
중급
고급
핵심 요약
개념 | 설명 |
---|---|
세그먼트 트리 | 구간 쿼리를 O(log n)에 처리 |
구간 쿼리 | 합, 최솟값, 최댓값, GCD 등 |
업데이트 | 값 변경도 O(log n) |
Lazy Propagation | 구간 업데이트를 효율적으로 |
세그먼트 트리 vs 다른 자료구조
자료구조 | 구간 쿼리 | 업데이트 | 공간 | 구현 난이도 |
---|---|---|---|---|
배열 | O(n) | O(1) | O(n) | 쉬움 |
Prefix Sum | O(1) | O(n) | O(n) | 쉬움 |
세그먼트 트리 | O(log n) | O(log n) | O(4n) | 어려움 |
Fenwick Tree | O(log n) | O(log n) | O(n) | 중간 |
세그먼트 트리 변형
1. 구간 합
self.tree[node] = self.tree[left] + self.tree[right]
2. 구간 최솟값
self.tree[node] = min(self.tree[left], self.tree[right])
3. 구간 최댓값
self.tree[node] = max(self.tree[left], self.tree[right])
4. 구간 GCD
import math
self.tree[node] = math.gcd(self.tree[left], self.tree[right])
주의사항
1. 트리 크기
# 안전하게 4배 할당
self.tree = [0] * (4 * n)
2. 인덱싱
# 0-based 인덱스 사용
# 왼쪽 자식: 2 * node + 1
# 오른쪽 자식: 2 * node + 2
3. 범위 확인
# 쿼리 범위가 노드 범위를 벗어나는지 확인
if right < start or end < left:
return 0 # 또는 적절한 기본값
4. Lazy Propagation 사용 시기
# 구간 업데이트가 많을 때만 사용
# 단순 업데이트는 일반 세그먼트 트리가 더 간단
세그먼트 트리 템플릿
기본 구조
class SegmentTree:
def __init__(self, arr):
self.n = len(arr)
self.tree = [0] * (4 * self.n)
self.build(arr, 0, 0, self.n - 1)
def build(self, arr, node, start, end):
if start == end:
self.tree[node] = arr[start]
else:
mid = (start + end) // 2
self.build(arr, 2*node+1, start, mid)
self.build(arr, 2*node+2, mid+1, end)
# 합치기 연산 (합, 최솟값, 최댓값 등)
self.tree[node] = self.tree[2*node+1] + self.tree[2*node+2]
def query(self, left, right):
return self._query(0, 0, self.n-1, left, right)
def _query(self, node, start, end, left, right):
if right < start or end < left:
return 0 # 기본값
if left <= start and end <= right:
return self.tree[node]
mid = (start + end) // 2
return (self._query(2*node+1, start, mid, left, right) +
self._query(2*node+2, mid+1, end, left, right))
def update(self, idx, val):
self._update(0, 0, self.n-1, idx, val)
def _update(self, node, start, end, idx, val):
if start == end:
self.tree[node] = val
else:
mid = (start + end) // 2
if idx <= mid:
self._update(2*node+1, start, mid, idx, val)
else:
self._update(2*node+2, mid+1, end, idx, val)
# 합치기 연산
self.tree[node] = self.tree[2*node+1] + self.tree[2*node+2]
활용 분야
1. 구간 통계
- 구간 합, 평균
- 구간 최솟값, 최댓값
- 구간 GCD, LCM
2. 온라인 쿼리
- 실시간 데이터 처리
- 스트리밍 통계
3. 알고리즘 최적화
- 동적 프로그래밍 최적화
- 분할 정복 최적화
4. 게임 개발
- 맵 구간 쿼리
- 충돌 감지