https://codingsmu.tistory.com/175
이전 글에서 먼저, 배열에서의 특정 구간의 합을 구할 때 O(N) 시간안에 구할 수 있는 누적합(Prefix Sum)알고리즘을 알아보았습니다. 이번 글에서는 더 개선된 시간인 O(logN)으로 구간합을 구할 수 있는 자료구조인 세그먼트 트리(Segment Tree)에 대해 알아보도록 하겠습니다.
세그먼트 트리(Segment Tree)란?
먼저, arr라는 숫자가 저장되어 있는 배열이 입력으로 들어왔을 때, 구간합을 트리를 활용해 표현해봅시다.
arr = [5,4,3,2,1] 이 저장되어 있다면, 아래와 같이 이진 트리 형태로 세그먼트 트리가 구성됩니다.
위의 그림과 같이, 각 노드에는 배열 인덱스 범위의 구간합과 트리 인덱스를 적어두었습니다.
(계산 편의를 위해 루트노드의 트리 인덱스는 1부터 시작합니다)
각 트리 노드에 구간합을 저장하는 과정을 살펴보도록 하겠습니다.
1. 세그먼트 트리에 구간합 저장하기 (update)
우선, 구간합을 저장할 세그먼트 트리를 0으로 초기화해주어야 합니다.
세그먼트 트리의 크기는 넉넉하게 배열 크기 N의 4배 정도로 선언해주면 됩니다.
tree = [0]*(len(arr)*4)
트리를 초기화 해주었다면, 배열의 첫 번째 원소부터 각 노드에 더해주면 됩니다.
다음과 같이 트리의 leaf node에서 루트노드까지 더해주는 bottom-up 접근 방식으로 진행하면 됩니다.
1. arr[0]부터 트리의 맨 왼쪽의 leaf node에 저장
2. 루트 노드(idx==1)에 도달할때까지 부모 노드로 이동하며, 모두 +arr[0]을 해줌
3. 다음 arr[1]을 같은 방식으로 루트노드 까지 모두 +arr[1]해줌.
4. arr의 마지막 배열까지 반복해주면, 구간합을 저장한 다음의 세그먼트 트리 완성.
이를, 코드로 구현하면 다음과 같은 N개의 배열에 대해 트리의 깊이인 logN만큼의 시간복잡도를 가지는 알고리즘으로 구현할 수 있습니다.
index = start_idx
for i in range(len(arr)): # O(N)
idx = index + i
while idx >= 1: # O(logN)
tree[idx] += arr[i]
idx //= 2
2. 첫 번째 배열의 세그먼트트리 위치 찾기
그렇다면 start_idx는 어떻게 구할까요?
배열의 크기인 N 보다 큰 (1 << n) = 2^n를 만족하는 n을 찾고, 2^n이 바로 start_idx가 됩니다.
이를 구하는 식은, 2의 제곱식으로 비트 연산자인 << 를 이용하면 아래와 같이 비교적 간단하게 구할 수 있습니다.
n = 0
while 2 ** n < len_arr: n += 1
start_idx = 1 << n
3. 특정 구간의 구간합 구하기
세그먼트 트리에서 i~j번째의 구간합을 구하고 싶다면, 다음과 같이 트리의 구간 합이 저장되어 있는 노드를 찾아서 더하기만 하면 됩니다.
3.1. 구간합이 노드 하나를 가리키는 경우
만약, 0~3번째 구간합을 구하고 싶다면, 트리 인덱스로 2를 반환해주면 됩니다.
구간합을 저장한 최종 트리 인덱스를 찾기위해서는 left = 0, right = 3으로 두고 bottom-up 방식으로 찾아가면 됩니다.
실제 트리 index는 start_idx인 8을 각각 더해줘야 하므로, left = 8, right = 11로 시작하면 됩니다.
left, right는 찾고자 하는 구간합 노드에 도달하기 위해 부모 노드로 이동해야 합니다.
left, right가 같은 트리 노드를 가리키면, 해당 노드를 반환해주면 됩니다.
3.2. 구간합이 여러개의 노드의 합일 경우
만약, 1~4번째 구간합을 구하고 싶다면, 트리 인덱스로 9, 5, 12번째의 노드를 더해주면 됩니다.
이 경우 left, right의이동은 다음과 같이 진행됩니다.
left가 짝수일 경우 부모노드의 구간합이 최종적으로 구하고자 하는 구간합에 포함되지만,
이 경우는 left가 홀수이므로, 부모 노드가 구간합에 포함되지 않습니다. 이 경우 현재 노드를 구간합에 저장하고 오른쪽 부모노드로 이동해야합니다. 마찬가지로 right가 짝수일 경우 현재 노드를 구간합에 저장하고 왼쪽 부모노드 이동해야합니다.
left, right가 같은 트리 노드를 가리키면, 해당 노드를 구간합에 저장하고 종료하면 됩니다.
이를 코드로 구현하면, 아래와 같은 logN의 시간복잡도를 가지는 코드로 표현할 수 있습니다.
left, right = i + start_idx, j + start_idx
sub_sum = 0
while left < right:
if left % 2 == 0: # 짝수면 부모 노드로 이동
left //= 2
else: # 홀수면 저장하고, 오른쪽 부모 노드로 이동
sub_sum += tree[left]
left = (left+1) // 2
if right % 2 != 0: # 홀수면 부모 노드로 이동
right //= 2
else: # 짝수면 저장하고, 왼쪽 부모 노드로 이동
sub_sum += tree[right]
right = (right-1) // 2
if left == right: sub_sum += tree[left]
위에서 설명한 세그먼트 트리의 초기화, 업데이트, 구간합 반환에 대한 전체적인 코드는 아래와 같습니다.
def update_tree(tree, start_idx, n):
# 트리 초기화
index = start_idx
for i in range(n):
idx = index + i
while idx >= 1:
tree[idx] += arr[i]
idx // 2
return tree
def get_sub_tree_sum(left, right):
sub_sum = 0
while left < right:
if left % 2 == 0:
left //= 2
else:
sub_sum += tree[left]
left = (left+1)//2
if right % 2 != 0:
right //= 2
else:
sub_sum += tree[right]
right = (right-1)//2
if left == right: sub_sum += tree[left]
return sub_sum
N, M = map(int, input().split()) # 배열의 크기, 구간합 쿼리의 갯수
arr = list(map(int, input().split())) # 배열
# 배열의 첫번째 숫자의 세그먼트 트리 인덱스 찾기
n = 0
while 2 ** n < len_arr: n += 1
start_idx = 1 << n
# 세그먼트 트리 초기화 및 업데이트
tree = [0] * (N * 4)
tree = init_tree(tree, start_idx, N)
for _ in range(M):
i, j = map(int, input().split()) # i~j번째의 구간합
sum = get_sub_tree_sum(i+start_idx, j+start_idx)
print(sum)
'Algorithm > 알고리즘 이론' 카테고리의 다른 글
파이썬으로 구현하는 구간합과 누적합(Prefix sum) (0) | 2024.06.01 |
---|---|
[그래프 탐색] 파이썬으로 구현하는 DFS, BFS (0) | 2023.12.10 |
[문자열] 파이썬으로 구현한 유효한 팰린드롬(Palindrome) (2) | 2023.11.25 |
이진 탐색(Binary Search)과 매개변수 탐색(Parametric Search) (0) | 2022.03.18 |
[그래프 이론] 플로이드-워셜(Floyd-Warshall) 알고리즘 (0) | 2022.01.02 |