코딩/알고리즘 & 자료구조

Segment Tree 심화 (4) - Euler Tour Tree(DFS Numbering Tree)

stonejjun 2020. 7. 6. 12:16

이번 Segment Tree 심화는 Euler Tour Tree이다. 난 DFS Ordering Tree라는 이름으로 먼저 접했고, 이쪽이 좀 더 내용을 설명하기에 적합하다고 생각하지만 DFS Ordering Tree라는 이름은 사용을 거의 안하는 것 같다. 그래도 난 일단 DFS Ordering Tree라는 이름을 기반으로 설명을 하려고 한다. 

사전 지식

Segment Tree, DFS

DFS Ordering Tree

이 테크닉의 전체적인 요점은 트리를 일자로 편다는 것이다. 특히 트리를 일자로 펴서 리프로 만들어 버리고 그 리프들을 이용해 새로 세그먼트 트리를 만드는 것이다. 이때 트리를 일자로 펴는 방식은 이름 그대로 DFS Ordering을 통한 방식을 사용한다. 

When?

우리는 DFS Ordering Tree를 어떨 때 사용하게 될까? 보통 DFS Ordering Tree는 값이 변경되면서 서브트리 쿼리가 주어질 때 사용되어진다. 
1 a 가 주어질 때마다 a번 노드의 서브트리에 있는 노드들의 가중치의 총합을 출력한다.
2 a b가 주어질 때마다 a번 노드의 가중치 값을 b로 바꾼다. 
와 같은 쿼리들을 처리할 수 있다. 물론 세그먼트 트리에서도 가능하듯이 1에서 원하는 값은 최대,최소,그 위치 등등 다양하게 잡을 수 있다. 이때 가장 중요한 점은 값이 바뀌어도 서브트리에 대해서 다음과 같은 작업 수행이 가능하다는 것이다.

How?

위와 같은 쿼리 수행이 왜 되는지를 알기 위해서 DFS Oredering의 방식을 알아야 한다. 사실 딱히 어려울 것도 없다. 주어진 트리의 루트에서부터 dfs를 하면서 i번 노드에 처음으로 방문했을 때 지금까지 방문한 노드의 수를 $S_i$라고 하고, i번 노드에서 가는 모든 방문을 끝냈을 때 지금까지 방문한 노드의 수를 $E_i$라고 한다. 즉, dfs과정에서 아래의 그림과 같이 두 개의 값을 구해나가게 된다. 

'

맨 처음 형태와 같은 트리가 있다고 하자. 우리는 루트인 1번 노드부터 dfs를 돌게 된다. 맨처음에 1번 노드에 방문을 하게 되면 지금까지 방문한 노드의 갯수는 1개이기 때문에 $S_1$은 1이다. 그 다음은 계속해서 2 6 순서대로 들어가게 된다. 2번 노드에 처음 방문한 시점에서 방문한 노드의 갯수는 총 2개이며, 6번 노드에 방문할 때는 총 3개를 방문하게 된다. 따라서 $S_2$는 2로 정해지고 $S_6$은 3으로 정해지게 된다. 

계속해서 dfs를 돌고 있고 현재까지 방문한 노드가 계속해서 3개인 상황에서 6번 노드에서 더이상 방문할 수 있는 노드가 없다. 똑같이 2번 노드에서도 더이상 나아갈 수 있는 노드가 없다. 따라서 $E_2$와 $E_6$ 모두 3으로 값이 정해진다. 그다음 다시 1번 노드에서 3번 노드로 탐색을 들어간다. 이때까지 탐색한 노드는 {1, 2, 3, 6} 총 4개고 $S_3$의 값은 4가 된다. 그 이후로 4번 노드도 탐색을 하게 되고, $S_4$는 5로 값이 정해진다. 4번 노드에서는 더이상 탐색을 할 노드가 없고 따라서 빠져나오게 된다. 이때 $E_4$의 값은 5로 정해진다.  

다시 3으로 나온 후에는 다음으로 탐색할 수 있는 5번 노드에 처음 들어간다. 이때는 모든 노드를 탐색한 이후이므로 $S_5$의 값은 6이 된다. 그다음은 어떤 노드에서도 더이상 탐색할 노드가 존재하지 않게 된다. 따라서 dfs를 계속해서 빠져나오면서 $E_1$, $E_3$, $E_5$ 의 값을 모두 6으로 정해주게 된다.


위와 같은 쿼리 수행을 끝나고 S값과 E값을 모두 살펴보자. 여기서 나타나는 한 가지 가장 중요한 특징은 어떤 노드 x에 대해서 x의 서브트리에 있는 노드들의 S값은 $S_x$와 $E_x$사이의 값을 가지게 된다는 것이다. 
즉 세그트리를 만들때 각 노드에 대해서 S값에 각 노드의 가중치같은 값들을 위치시키고, 어떤 노드 x에 대해서 그 서브트리에서 계산을 해야하는 쿼리가 들어오면 세그먼트 트리에서 $S_x$와 $E_x$ 사이에서 값을 계산을 하면 된다. 

코드

아래의 코드는 BOJ 2820 자동차공장을 푸는 코드이다. 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#include<bits/stdc++.h>
using namespace std;
typedef long long int ll;
 
#define ff first
#define ss second
#define ep emplace_back
#define eb emplace
#define pb push_back
 
ll s[1010101];
ll e[1010101];
ll arr[1010101];
ll vis[1010101];
vector<ll> v[1010101];
 
struct node{
    ll laz,val;
};
 
node tree[2010101];
 
ll upt(ll l,ll r,ll val,ll s,ll e,ll nod){
    if(tree[nod].laz!=0){
        tree[nod].val+=(e-s+1)*tree[nod].laz;
        if(s!=e){
            tree[nod*2].laz+=tree[nod].laz;
            tree[nod*2+1].laz+=tree[nod].laz;
        }
        tree[nod].laz=0;
    }
 
    if(r<s||e<l) return tree[nod].val;
    if(l<=s&&e<=r){
        tree[nod].laz+=val;
        tree[nod].val+=(e-s+1)*tree[nod].laz;
        if(s!=e){
            tree[nod*2].laz+=tree[nod].laz;
            tree[nod*2+1].laz+=tree[nod].laz;
        }
        tree[nod].laz=0;
        return tree[nod].val;
    }
    return tree[nod].val=upt(l,r,val,s,(s+e)/2,nod*2)+upt(l,r,val,(s+e)/2+1,e,nod*2+1);
}
 
 
ll sol(ll idx,ll s,ll e,ll nod){
    if(tree[nod].laz!=0){
        tree[nod].val+=(e-s+1)*tree[nod].laz;
        if(s!=e){
            tree[nod*2].laz+=tree[nod].laz;
            tree[nod*2+1].laz+=tree[nod].laz;
        }
        tree[nod].laz=0;
    }
    if(idx<s||idx>e) return 0;
    if(s==e) return tree[nod].val;
    return sol(idx,s,(s+e)/2,nod*2)+sol(idx,(s+e)/2+1,e,nod*2+1);
}
 
ll cnt=0;
void dfs(ll x){
    vis[x]=1;
    cnt++;
    s[x]=cnt;
    for(auto k:v[x]){
        if(vis[k]) continue;
        dfs(k);
    }
    e[x]=cnt;
}
 
string ss;
 
 
int main(){
    cin.tie(0);
    ios_base::sync_with_stdio(0);
    ll i,j,k,l,m,n;
    cin>>n>>m;
    cin>>arr[1];
    for(i=2;i<=n;i++){
        ll a;
        cin>>arr[i]>>a;
        v[a].pb(i);
    }
    dfs(1);
    for(i=1;i<=n;i++){
        upt(s[i],s[i],arr[i],1,n,1);
    }
 
    while(m--){
        ll a,b;
        cin>>ss>>a;
        if(ss[0]=='p'){
            cin>>b;
            if(s[a]==e[a]) continue;
            upt(s[a]+1,e[a],b,1,n,1);
        }
        else{
            cout<<sol(s[a],1,n,1)<<'\n';
        }
 
    }
}