코딩/알고리즘 & 자료구조

알고리즘 & 자료구조 복기글 (6) Segment Tree

stonejjun 2020. 2. 19. 03:51

한 번 배웠다하면 응용되는 부분도 엄청많고 정말 많은 문제에서 다양하게 사용되는 자료구조인 Segment Tree를 소개하려 한다.

예제 - 구간 합 구하기 1

링크는 https://www.acmicpc.net/problem/2042 다.

평소처럼 Segment Tree란? 으로 시작하지 않고 예제를 가지고 오며 시작한 이유는 간단하다. Segment Tree의 역할, 존재의의, 이점등을 가장 잘 설명할 수 있는 문제이기 때문이다. 워낙 유명한 문제인 만큼 풀린사람 수도 엄청나다는 것을 알 수 있다. Segment Tree를 알면서 이 문제를 모르기는 쉽지 않다.

이 문제에서 요구하는 사항은 두 가지다. 
1. 배열의 값을 바꾼다.  2. 구간의 합을 구한다.

시간을 생각하지 않고 가장 쉽게 풀 수 있는 방법은 당연히 그냥 배열을 잡아서 숫자를 바꾸고 구간을 더하는 것이다. 이렇게 하면 1번 쿼리는 O(1)만에 처리해 줄 수 있지만 2번 쿼리가 O(N)이 걸리기 때문에  최악의 경우에 O(QN)의 시간 복잡도로 시간초과가 나게 된다.

이와 반대로 prefix sum을 이용해서 구간의 합을 구하면 2번 쿼리는 O(1)만에 처리해 줄 수 있지만, prefix sum이 n개이기 때문에 1번 쿼리를 처리해 주는 데 O(N)이 걸리게 된다. 결국 이 방법도 최악의 경우 O(QN)으로 시간 초과가 난다. 

위의 두 방식은 한 쿼리를 O(1)에 처리 하는 대신 나머지 한 쿼리의 처리가 O(N)이기 때문에 시간 초과가 난다. 이 둘의 밸런스를 둘다 O(lgN)에 처리할 수 있도록 만들어 주는 자료구조가 Segment Tree이다. 

Segment Tree 란?

구간의 합을 저장하고 있는 트리이다. 1번 값은 1~2번 값의 합에 포함되어 있으며, 1~4번 값의 합에 포함 되어있다.
누가 만약 1~5번 의 구간합을 물어본다면 답변을 1~4번의 구간 합 + 5번의 값으로 말해 줄 수 있다. 이런식으로 답변할 때도 최대 lgN 개의 합, 한 배열의 값도 최대 lgN개의 구간합에 들어있게 저장을 한다.

How It works? 

길이 8의 배열을 잡아보자. 실제로는 그렇지 않지만 이해를 돕기 위해서 2^n으로 잡았다.
배열을 [3,4,1,7,2,8,5,6]이라고 설정하자. 그러면 초기에 다음과 같은 트리가 생성된다.

괄호안에는 그 노드가 담당하고 있는 인덱스의 범위를 이야기하고 있는 것이고, 위의 값은 담당하는 범위내의 구간의 합을 의미한다. 리프가 아닌 어떤 노드의 구간과 구간합은 각각 두 자식노드의 구간을 합친 구간과, 두 자식노드의 구간을 합친 구간합이 된다.

이 상황에서 두 가지 쿼리를 처리하면 어떤 식으로 트리가 변하는 지를 설명하려고 한다. 예를 들어 1 3 4란 쿼리가 들어왔다고 가정을 해보자. 이는 즉, 3번째 값이 +3이 되었다는 소리이다. 따라서 루트에서부터 내려가면서 3번을 포함하고 있는 구간에 +3을 해주면 된다. 

다음과 같이 내려가면서 노드의 값을 바꿔준다면 

이런 형태의 트리가 나오고 이것은 우리가 원했던 배열이 [3,4,4,7,2,8,5,6] 일때의 트리와 일치하다. 

그렇다면 이 형태의 트리에서 구간 합은 어떻게 구할까? 쭉 내려가면서 어떤 노드가 담당하는 구간이 내가 원하는 구간에 완전히 포함되어 있으면 더하고 관련이 없으면 무시하고, 아니면 내려가면 된다.
만약 1~7의 구간의 합을 구하고 싶다고 해보자.

1~7의 구간은 1~4 + 5~6 + 7~7로 표현 할 수 있다. 그래서 그림에서 이렇게 세 노드의 합을 구하면 알아낼 수 있다.

How To Code?

어떻게 코딩을 해야할까? 일단 어떠한 노드는 두 자식의 노드에 영향을 받는다는 것을 알 수 있다. 이를 포인터를 이용할 수 도 있지만, 배열에서 쉽게 트리를 구현 할 수 있는 방법이 있다.

이 그림은 그냥 순서대로 노드에 번호를 매긴것이다. 여기서 자식과 부모의 노드 번호 관계를 잘 보자. 세그트리가 완전 이진트리이기 때문에 나타나는 특징으로 번호x 인 노드에 대해서 좌측 자식 노드는 x*2 우측 자식 노드는 x*2+1이라는 특징이 있다. 이를 이용해 보통 재귀적으로 코드를 짠다.
그렇다면 이제 본격적으로 두 가지 쿼리가 어떻게 처리되는지 자세히 설명하려고 한다. 

Query 1.

1번 쿼리에 대해서 상황은 두 가지가 있을 수 있다. 
i) 업데이트하는 인덱스가 범위 내에 없을 경우
ii) 업데이트하는 인덱스가 범위 내에 있을 경우

i)의 경우에는 그냥 재귀를 끝내면 된다. 전체 범위에 포함되지 않다면 그 세부 구간에도 당연히 포함이 되지 않는다.
ii)의 경우에는 포함을 하고 있으므로, 현재 노드에 변화를 시켜주고 두 자식노드에 이 행동을 그대로 다시 전달해준다. 

이 과정을 계속 반복하면 리프노드에 도달할텐데, 리프노드에서는 ii)의 상황이어도 업데이트만 해주고 자식에게 전달을 하지 않으면 된다. update 함수를 작성한 코드를 보자.

1
2
3
4
5
6
7
8
9
ll tree[4040404];
void upt(ll idx, ll val, ll s ,ll e,ll nod) 
{
    if(e<idx||s>idx) return ;
   tree[nod]+=val;
    if(s==e) return;
    upt(idx,val,s,(s+e)/2,nod*2);
    upt(idx,val,(s+e)/2+1,e,nod*2+1);
}
cs


다섯개의 인자는 각각 업데이트 하고자 하는 인덱스, 변화량, 현재 노드가 담당하고 있는 구간의 시작, 끝, 현재 노드의 번호를 가르킨다. 

4번째 줄은 i)의 경우를 처리해 준다. 현재 노드가 담당하는 범위에 업데이트할 인덱스가 없으면 그냥 종료한다.
5번째 줄은 ii)의 경우를 처리해 준다. i)의 경우가 아닌경우 업데이트를 변화량만큼 해준다.
6번째 줄은 리프노드일때 더이상 자식에게 전달을 하지 않고 재귀함수를 종료하는 역할이다.
7~8번째 줄에서 자식 노드에게 그대로 전달을 해준다.

Query2.

2번 쿼리에 대해서는 상황이 세가지가 나오게 된다.

i) 노드가 담당하는 구간이 합을 구하는 구간과 전혀 겹치지 않는경우
ii) 노드가 담당하는 구간이 합을 구하는 구간에 완벽히 포함이 되는 경우
iii) 노드가 담당하는 구간이 합을 구하는 구간에 일부만 포함이 되는 경우

2번 쿼리도 1번 쿼리처럼 재귀적으로 내려가면서 처리 해 줄 것이다. 
i)의 경우에는 그냥 재귀를 종료해 주면서 0을 반환해 주면 된다.
ii)의 경우에는 전체가 포함이 되므로 노드가 담당하고 있는 구간합을 반환해 주면 된다.
iii)의 경우에는 i),ii)의 대처를 모두 할 수 없으므로, 그냥 자식 두개에 전달을 해준다.

iii)인 경우에만 내려가면서 어떤 노드가 담당하는 구간이 ii)이면 더해주고 i)이면 안 더해주는 식으로 진행된다는 것이다. 코드를 보면 좀 더 이해가 될 것이라고 생각한다. 

1
2
3
4
5
6
7
ll sol(ll l ,ll r, ll s,ll e,ll nod)
{
    if(e<l||s>r) return 0;
    if(l<=s&&e<=r) return tree[nod];
    return sol(l,r,s,(s+e)/2,nod*2)
            +sol(l,r,(s+e)/2+1,e,nod*2+1);
}
cs

 

다섯개의 인자는 차례대로 구하려는 합의 범위 (l,r), 현재 노드가 담당하고 있는 합의 범위 (s,e), 현재 노드의 번호(nod) 이다.

3번째 줄은 i)의 경우로 [l,r]과 [s,e]가 아예 겹치는 구간이 없는경우에 0을 반환한다.
4번째 줄은 ii)의 경우로 [l,r]에 [s,e]가 포함되어 있는 경우, 즉 구간 전체의 합을 반환한다.
5~6번째 줄은 iii)의 경우로 자식노드로 보낸다음에 리턴값을 합쳐서 리턴을 해준다.

Code Full ver.

다음은 언급했었던 BOJ 2042 구간 합 구하기의 전체 코드이다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include<bits/stdc++.h>
using namespace std;
typedef long long int ll;
const int maxn=3e5+5;
ll tree[4040404];
ll arr[1010101];
ll n,m;
 
void upt(ll idx, ll val, ll s ,ll e,ll nod) 
{
    if(e<idx||s>idx) return ;
    tree[nod]+=val;
    if(s==e)return;
    upt(idx,val,s,(s+e)/2,nod*2);
    upt(idx,val,(s+e)/2+1,e,nod*2+1);
}
 
ll sol(ll l ,ll r, ll s,ll e,ll nod)
{
    if(e<l||s>r) return 0;
    if(l<=s&&e<=r) return tree[nod];
    return sol(l,r,s,(s+e)/2,nod*2)
            +sol(l,r,(s+e)/2+1,e,nod*2+1);
}
 
int main()
{
    ll n,i,j,k,l,m,a,b,c;
    scanf("%lld %lld %lld",&n,&m,&k);
    for(i=1;i<=n;i++)
    {
        scanf("%lld",&arr[i]);
        upt(i,arr[i],1,n,1);
    }
 
    for(i=1;i<=m+k;i++)
    {
       scanf("%lld %lld %lld",&a,&b,&c);
       if(a==1){
               ll dif=c-arr[b];
               arr[b]=c;
               upt(b,dif,1,n,1);
       }
       else{
               printf("%lld\n",sol(b,c,1,n,1));
       }
 
    }
}
cs

추천 문제

BOJ 2042 구간 합 구하기
BOJ 14438 수열과 쿼리 17
BOJ 3653 영화 수집
BOJ 5817 고통받는 난쟁이들
BOJ 7578 공장