본문 바로가기

Algorithms/알고리즘 개념

[알고개념] 1. Segment Tree

1. 알고리즘 개념

  • 세그먼트 트리: 각 노드가 구간 정보를 가지고 있는 트리를 의미한다.
  • 루트 노드는 전체 구간을 가지고, 리프 노드는 길이가 1인 구간들을 가진다.
  • 리프 노드를 제외한 다른 모든 노드는 항상 2개의 자식을 가진다. -> 완전 이진 트리 꼴
  • 리프 노드가 $N$개인 완전 이진 트리의 전체 노드의 개수는 $2N-1$이고, 높이는 $\lceil log\ N \rceil$ 이다.
  • $N$이 2의 제곱꼴이 아닐 경우에는 남는 구간을 기본값으로 채워서 사용한다.

https://blog.naver.com/kks227/220791986409

 

2. 알고리즘 구현

  • start == end : 리프 노드는 구간 길이가 1이므로 바로 값을 대입한다.
  • node의 왼쪽 자식은 node*2이고, 오른쪽 자식은 node*2+1이다.
  • node에 저장된 구간이 [start,end]라면, 왼쪽 자식은 [start, (start+end)/2], 오른쪽 자식은 [(start+end)/2+1, end]가 저장된 구간이다.
  • tree[node]에 저장될 값을 구하려면 왼쪽 자식에 저장된 값 tree[node*2], 오른쪽 자식에 저장된 값 tree[node*2+1]을 먼저 구해야 하므로 재귀함수를 호출한다.
  • Tip: 값의 개수를 n이라 할 때 세그먼트 트리 배열의 크기는 4*n보다 항상 작다. (세그먼트 트리는 완전 이진 트리이므로 리프 노드가 2의 제곱 개수로 기본값이 생성될 경우도 존재하는데, 4*n 만큼 할당해두면 메모리가 충분하다.)
// https://book.acmicpc.net/ds/segment-tree
// a: 배열 A
// tree: 세그먼트 트리
// node: 노드 번호
// node에 저장되어 있는 합의 범위가 start - end
void init(vector<long long> &a, vector<long long> &tree, int node, int start, int end) {
    if (start == end) {
        tree[node] = a[start];
    } else {
        init(a, tree, node*2, start, (start+end)/2);
        init(a, tree, node*2+1, (start+end)/2+1, end);
        tree[node] = tree[node*2] + tree[node*2+1];
    }
}

 

2042번. 구간합 구하기 문제에서 세그먼트 트리를 응용하면 아래와 같이 코드를 짤 수 있다. 여기선 리프 노드들의 값을 먼저 입력받고난 후 반복문을 이용하여 부모 노드들의 값을 채웠다.

  1. 4*n 크기의 트리 배열 생성
  2. 리프 노드 인덱스를 찾아서 리프 노드들에 값 대입
  3. 가장 아래에 있는 부모 노드부터 vec[i] = vec[i*2] + vec[i*2+1] 로 값 대입
  4. sum(L, R, nodeNow, nodeL, nodeR): 구간 [L, R] 의 구간합을 구하는 함수 (첫 시작 노드는 루트, 재귀를 거듭하면서 구간을 반씩 줄이면서 다음 구간이 포함되는지 확인)
  5. update(idx, value): idx번째 값을 value로 업데이트하는 함수. 본인 위의 부모들 (2로 나눈 인덱스들)의 값도 모두 갱신해준다.
#include <bits/stdc++.h>

using namespace std;

int n, m, k, treeSize, nodeSize;
vector<long long> vec;

// Segment tree : 구간합 구하기
// L, R: 구하려는 구간 / nodeNow: 현재 노드 / nodeL, nodeR: 그 노드가 포함되는 구간
long long sum(int L, int R, int nodeNow, int nodeL, int nodeR) {
	// cout << "[ " << nodeL << ", " << nodeR << " ]\n";
	// 1. 전혀 안겹치는 경우
	if (R < nodeL || nodeR < L) return 0; 

	// 2. 완전 포함되는 경우
	if (L <= nodeL && nodeR <= R) return vec[nodeNow];

	// 3. 그 외의 경우
	int mid = (nodeL + nodeR) / 2;
	return sum(L, R, nodeNow * 2, nodeL, mid) + sum(L, R, nodeNow * 2 + 1, mid + 1, nodeR);
}

void update(int idx, long long value) {
	idx += nodeSize;

	vec[idx] = value;
	while (idx > 1) {
		idx /= 2;
		vec[idx] = vec[idx * 2] + vec[idx * 2 + 1];
	}
}

int main() {
	cin.tie(NULL);	cout.tie(NULL);	std::ios::sync_with_stdio(false);

	cin >> n >> m >> k;
	vec.resize(4 * n, 0);
	treeSize = ceil(log2(n));
	nodeSize = pow(2, treeSize) - 1;

	for (int i = nodeSize + 1; i < nodeSize + 1 + n; i++) 
		cin >> vec[i];

	for (int i = nodeSize; i >= 1; i--) 
		vec[i] = vec[i * 2] + vec[i * 2 + 1];

	for (int i = 0; i < m + k; i++) {
		long long a, b, c;
		cin >> a >> b >> c;
		if (a == 1) update(b, c);
		else cout << sum(b, c, 1, 1, nodeSize + 1) << "\n";
		
	}
}

 

3. 시간 복잡도

세그먼트 트리를 구축하는데 걸리는 시간은 $O(n)$, 세그먼트 트리를 응용하여 구간합을 구하는데 걸리는 시간은 $O(log\ n)$이다.