코딩/알고리즘 & 자료구조

Segment Tree With Lazy Propagtion

stonejjun 2020. 2. 29. 11:42

이 글은 Segment Tree에 대한 이해가 어느정도 완벽히 되어있다고 가정하고 쓰는 글입니다. Top-Down 방식으로 구현이 되며, Top-Down 세그트리에 대한 이해가 부족하면 이 글을 먼저 읽는 것을 추천합니다.

예제 - 구간 합 구하기 2

링크는 https://www.acmicpc.net/problem/10999 이다.
이번에도 문제에 대해서 이야기 하면서 Lazy Propagation에 대해서 소개해보려고 한다.

이 문제에서 요구하는 사항은 두 가지 이다.
1. 배열 중 구간에 일정 값을 더한다 2. 배열의 구간의 합을 구한다.

Segment Tree를 생각해보자. 2번 쿼리는 기존 Segment Tree에 존재하는 기능이다. 하지만 여기서 문제가 되는 것은 1번 쿼리이다. 기존의 방식으로 Update를 하게 되면 N번 업데이트를 해야하기 때문에 한 번 쿼리를 처리하는데 O(NlogN) 총 O(QNlogN)의 시간복잡도를 가지게 된다.

해결 방법 : 1번 쿼리도 2번 쿼리처럼 처리하자!

다음과 같은 트리가 있다고 해보자. 1~8인 구간의 합을 구할 때 우리는 각각의 합을 구하는 대신 1~5,6~7,8~8의 구간의 합을 더해줌으로써 구했다. 
똑같이 1~8까지 k를 더할때, 우리는 8번 내려가면서 k를 더하는 방법대신에 1~5,6~7,8~8의 세 구간에 k를 더하겠다는 암시를 해둔다.

Lazy Propagtion

하지만 결국 다 바꾸게 된다면 시간복잡도는 쿼리 한 개당 O(NlogN)이다. 하지만 위에서 '암시'라는 단어를 사용했고, 이름도 'Lazy' Propagation이다. 쿼리에서 해주는 역할은 '암시'뿐이다. 실제로 맨 아래 리프노드의 값을 변화시키지는 않고 있다가 최대한 게으르게, 그 노드의 값을 사용해야 될 때가 되서야 그 값을 계산하면서 추가해 준다. 즉, O(log N)인 2번 쿼리를 처리할 때 내려가면서 같이 덤으로 얹어서 계산을 하겠다는 것이다.

How It Works?

지난번에 구간 합을 구할 때 처럼 내가 업데이트 하려고 하는 구간에 현재 보는 노드가 완벽히 포함될 때 그 노드에 암시를 해 주고 오면 된다. 

위의 트리에서 1~6까지 3을 더하는 쿼리를 처리한다고 해보자. 

그러면 이렇게 두 개의 노드에 표시를 해주면 된다. 이후 이 표시가 어떻게 되는지를 알아보자.

그런데 여기서 문제가 생겼다. 다음 입력에 1부터 9까지의 합을 물어보면 나는 바로 28이라고 대답을 할 것이다. 하지만 이는 답이 아니다. 이 문제는 '암시'를 해주면서 그 위로는 미리 계산을 해 둠으로서 구할 수 있다.

이때 중요한 사실은 '어떤 노드가 담당하고 있는 구간 모두가 변한다면 그 노드도 어떻게 변할지 알고있다.' 라는 것이다.
예를 들어 원래 1~5까지의 구간합은 15였지만, 그 구간은 모두 +3이 되므로 이후 1~5까지의 구간합은 아무튼 15+(5-1+1)*3=30 이 된다. 이 값을 위로 반환해서 변화하는 노드 위로는 미리 값을 계산한다. 다시 말해서 게으르게 전파되는 것은 아래쪽으로만 이라는 것이다. 

또한 전체 구간이 아니라 아래의 세부 값들에도 모두 업데이트 해야 하므로 아래로 전파 시켜줘야 된다. 따라서 업데이트 후 정확한 트리의 모습은 다음과 같다.

코드 

 

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include<bits/stdc++.h>
using namespace std;
typedef long long int ll;
 
ll arr[1010101];
struct Node{
    ll val,laz; // 그 값, 암시량
};
 
Node tree[4040404];
 
void uptr(ll dif,ll l,ll r,ll s,ll e,ll nod)
{
    if(tree[nod].laz!=0){ // 구간 변화량이 있으면
        tree[nod].val+=(e-s+1)*tree[nod].laz; // 현 노드의 값 바로 계산
        if(s!=e){                              //가능하면 아래로 전파
            tree[nod*2].laz+=tree[nod].laz; 
            tree[nod*2+1].laz+=tree[nod].laz;
        }
        tree[nod].laz=0;
    }
 
    if(r<s||e<l) return;
 
    if(l<=s&&e<=r){
        tree[nod].val+=(e-s+1)*dif;
        if(s!=e){
            tree[nod*2].laz+=dif;
            tree[nod*2+1].laz+=dif;
        }
        return;
    }
    uptr(dif,l,r,s,(s+e)/2,nod*2);
    uptr(dif,l,r,(s+e)/2+1,e,nod*2+1);
    tree[nod].val = tree[nod*2].val+tree[nod*2+1].val; 
    //윗 노드들은 아래 노드의 변화를 바로 
}
 
ll sol(ll l,ll r,ll s,ll e,ll nod)
{
    if(tree[nod].laz!=0){
        tree[nod].val += (e-s+1)*tree[nod].laz;
        if(s!=e){
            tree[nod*2].laz+=tree[nod].laz;
            tree[nod*2+1].laz+=tree[nod].laz;
        }
        tree[nod].laz=0;
    }
 
    if(r<s||e<l) return 0;
    if(l<=s&&e<=r) return tree[nod].val;
    return sol(l,r,s,(s+e)/2,nod*2)+
            sol(l,r,(s+e)/2+1,e,nod*2+1);
}
 
int main()
{
    ll i,j,k,l,m,n;
    scanf("%lld %lld %lld",&n,&m,&k);
 
    for(i=1;i<=n;i++)
        cin>>arr[i];
    init(1,n,1);
 
    for(i=1;i<=m+k;i++)
    {
        ll t1,a,b,c; 
        scanf("%lld",&t1);
        if(t1==1){ 
            scanf("%lld %lld %lld",&a,&b,&c);
            uptr(c,a,b,1,n,1); 
        }
        else
            cin>>a>>b;
            cout<<sol(a,b,1,n,1)<<'\n';
        }
 
    }
}
 
 
cs

 

이번 글은 기본적으로 개념 자체가 나에게 직관적이지 못해서 완벽하게 설명을 하지 못했던 것 같다. 실제로 최근이 되어서야 Lazy를 그래도 좀 잘 다룬다라고 말할 수 있게 되었다. 하지만, 정말 중요한 개념 중 하나이기 때문에 천천히 내용, 코드를 자세히 살펴 보면 좋을 것 같다.