코딩/알고리즘 & 자료구조

Segment Tree 심화 (5) - HLD (Heavy-Light Decomposition)

stonejjun 2020. 9. 10. 22:33

이번에 소개할 Segment Tree 심화는 HLD이다. 원래는 Euler Tour 이후에 바로 글을 작성했어야 하는데, 대회 준비등의 이유로 다른 내용 관련 글을 계속 쓰다 보니까 계속 미뤄져서 이제서야 글을 쓰게 되었다. 

사전 지식 (먼저 읽고 와야하는 글)

stonejjun.tistory.com/103 세그 먼트 트리 - euler tour tree 설명글

What is HLD?

그러면 본격적으로 HLD는 무엇일까? HLD는 일단 Heavy - Light Decomposition의 약자이다. 무거운거 - 가벼운거 분해. 즉 Decomposition인데 Heavy Groupr과 Light Group으로 분리한다는 것이다. 그러면 여기서 좀 세부적으로 알아야 할 사항이 생긴다.

1. 뭐를 분리하는 것인가?
2. 무엇을 기준으로 분리하는 것인가?

1에 대한 답은 트리의 간선이다.
2에 대한 답은 조금 어렵다. 

Why HLD?

일단 분류 기준과 구현을 알기 전에, 우리가 왜 HLD를 써야하는지 알아야 이 글을 읽는 필요성이 납득이 될 것이다. HLD가 처리할 수 있는 가장 핵심적인 것은 경로 쿼리이다. 

트리에서 (a,b)가 주어지면 a에서 b로 향하는 경로에 있는 노드의 가중치를 모두 더한 값을 구하는 쿼리가 들어온다고 생각해보자. 두 점의 lca를 c라고 하면 a에서 c까지의 경로의 합 + b에서 c까지의 경로의 합으로 구할 수 있다. 우리가 지금까지 알고 있는 방법은 a에서 c까지 한 칸씩 올라가면서 더하는 것이다. 당연히 쿼리당 O(N)으로 그렇게 좋지 못하다는 것을 알고 있다. 

우리는 간선을 특정 규칙에 따라서 적당히 몇 개의 체인으로 나누어 줄 것이다. 이때 a와 c가 같은 체인이 아니라면 lca까지 올라가는 과정에서 체인으로 이어진 간선은 한 번에 묶어서 타고 올라갈 수 있게 된다는 것이다. 이를 통해서 a에서 c까지 가는 스텝의 횟수가 $O(lgN)$의 시복도를 가짐을 보장할 수 있다. 여기에 전 글에서 작성한 트리를 일자 형식, 세그로 바꾸는 기술을 잘 섞으면 쿼리당 $O(lg^2N)$에 쿼리의 처리가 가능하게 된다. 

분류의 기준

트리의 한 노드 x를 생각하자. x의 자식으로 향하는 간선 중에서 절반 이상의 노드의 갯수를 가지고 있는 트리로 향하는간선을 heavy 간선으로 선택한다. 말로 일단 설명을 하기는 했지만, 말로 설명하기가 너무 어렵고 제대로 전달이 된 것 같지 않아서 예시를 가지고 왔다. 

 

 

 맨 왼쪽 의 1번을 기준으로 생각하면 총 노드의 수는 6개이다. 이때 3을 루트로 하는 서브 트리의 크기는 3으로 6의 절반 이상이다. 따라서 1 - 3의 간선이 heavy 간선 1 - 2 간선이 light 간선이 된다.

가운데에서는 4개의 총 노드 중에서 2번을 루트로 하는 서브트리의 크기가 3이다. 따라서 1-2 간선이 heavy간선이 된다.

가장 오른쪽의 그림과 같은 경우에는 총 노드의 개수가 7개인데 각각을 루트로 하는 서브 트리의 크기가 3,2,1 이다. 따라서 절반 이상이 되는 서브트리가 없으므로 1-2 1-3 1-5 모두 light 간선이다.
하지만, 경우에 따라서 그냥 가장 크기가 큰 서브트리를 담당하는 노드로 향하는 간선을 heavy 간선이라고 하기도 한다. 그러한 경우에는 1-2가 heavy 간선이 되는 것이다. 

당연하게도 한 노드에 대해서 여러개의 heavy 간선은 존재하지 않는다. 

분류와 체인

각 노드에서 heavy edge를 싹다 칠하게 되면 아래와 같은 그림이 나오게 된다. 왼쪽의 그림은 여기 에서 가져왔으며 노드에 번호를 붙여 오른쪽 그림을 만들었다.

 

좀 더 자세히 설명하자면  heavy egde를 모두 구한 다음에, 연결이 되는 heavy edge끼리는 같은 색깔을 칠해주었다. 여기서 같은 색으로 칠해진 간선들이 위에서 말했던 같은 그룹 혹은 같은 체인에 속하는 간선들이라고 한다. 
또한 따로 따로 떨어진 light edge들도 각각 한 개 짜리 길이의 체인이라고 생각을 한다. 

여기서 몇 가지 사실을 통해서 어떤 노드에서 다른 노드로 갈 때 방문하는 체인의 갯수가 O(lgN) scale이라는 것을 증명할 수 있다. 

1. 하나의 heavy chain과 그 다음 heavy chain을 가기전에는 필수적으로 하나의 light chain을 거치게 된다. 만약 그렇지 않다면 이미 두 개의 체인이 하나의 heavy chain으로 이어졌어야 한다.
2. light chain을 타고 올라가면 서브트리의 노드의 수가 두배 이상이 된다. 가장 크지 않은 그룹에서 올라왔으니, 올라가면 크기가 2배 이상이 될 수 밖에 없다.

따라서 최대 lgN 스케일의 chain을 지날 수 밖에 없음이 보장이 된다.

코드

hld는 짠지 조금 오래되기도 했고, 확실히 코드를 기억하고 싶기도 하고, 여러개의 함수로 이루어져 있기 때문에 코드를 세분화시켜서 분석하려고 한다. 기본적으로 hld를 justicehui.github.io/hard-algorithm/2020/01/24/hld/ 에서 배웠기 때문에 이 부분의 내용은 저 블로그 글의 내용과 상당히 흡사하다. 코드도 거의 비슷하다고 할 수 있다.

이번 글의 코드는 이 문제를 푸는 코드 기준으로 한다. 

Part 1

void dfs1(ll x){
	sz[x]=1;
	for(auto n:v[x]){
		if(sz[n]) continue;
		dep[n]=dep[x]+1;
		par[n]=x;
		dfs1(n);
		sz[x]+=sz[n];
		dv[x].push_back(n);
		if(sz[n]>sz[dv[x][0]]) swap(dv[x][0],dv[x].back());
	}
}

sz[i]는 i를 루트로 하는 서브트리의 크기를 담아두는 배열이다. vis대신에 sz를 사용하여 정점에 방문한 적이 있는지 체크했으며, 재귀를 통해 모든 자식노드의 서브트리 크기를 현재 서브트리 크기에 더해준다. dep는 그 점의 레벨(깊이)를 담아두는 배열이다. 이후에 필요하다.
dv에서는 par로 가는 간선을 제외한 나머지 모든 간선을 담아둔다. 그와중에 마지막 줄의 교환을 통해서 자식 중 가장 큰 서브트리로 향하는, 즉 heavy edge를 dv의 맨 앞으로 끌고 와준다.

Part 2

void dfs2(ll x){
	np++;
	in[x]=np;
	for(auto k:dv[x]){
		if(dv[x][0]==k) top[k]=top[x];
		else top[k]=k;
		dfs2(k);
	}
	out[x]=np;
}

굉장히 전형적인 euler tour을 하는 과정이다. 모른다면 위의 추천글을 읽고 오는 것이 좋다. 여기서 5번째, 6번째 줄에 집중을 해야한다. top[i]는 i가 속한 heavy chain의 가장 위에 있는 원소이다.
heavy edge로 내려간 노드는 heavy chain이 계속 이어지는 것이기 때문에 현재 노드의 top이 자식 노드의 top이 된다.
light edge로 내려간 노드는 그 노드가 항상 새로운 heavy chain의 시작점, 즉 top이 된다.

Part 3

node upt(ll idx,ll s,ll e,ll nod){
	if(idx<s||idx>e) return tree[nod];
	if(s==e){
		tree[nod].val=1LL-tree[nod].val;
		if(tree[nod].val) tree[nod].left=s;
		else tree[nod].left=1e18;
		return tree[nod];
	}
	node a=upt(idx,s,(s+e)/2,nod*2);
	node b=upt(idx,(s+e)/2+1,e,nod*2+1);
	tree[nod].val=a.val+b.val;
	tree[nod].left=min(a.left,b.left);
	return tree[nod];
}

ll sol(ll l,ll r,ll s,ll e,ll nod){
	if(r<s||e<l) return 1e18;
	if(l<=s&&e<=r) return tree[nod].left;
	return min(sol(l,r,s,(s+e)/2,nod*2),sol(l,r,(s+e)/2+1,e,nod*2+1));
}

part3는 세그먼트 트리 부분이다. 문제에 따라서 계속 바뀌게 되는데, hld를 공부할 정도면 어떤 종류의 세그를 짜야하는 지에 대해서는 충분히 생각할 수 있다고 생각한다.

Part 4

ll psol(ll a,ll b){
	ll ans = 1e18;
	while(top[a]!=top[b]){
		if(dep[top[a]] > dep[top[b]]) swap(a, b);
		ll x=top[b];
		ans=min(ans,sol(in[x],in[b],1,n,1));
		b=par[x];
	}
	if(dep[a]>dep[b]) swap(a, b);
	ans=min(ans,sol(in[a],in[b],1,n,1)); 
	return ans;
}

두 개의 체인이 다르다면 둘 중에 더 깊이있는 노드를 b라고 한다. b를 b가 속한 heavy chain에 대해서 계산을 해준다. b에서 체인의 가장 위쪽까지를 묶어서 처리를 해주었으므로 b를 top[b]의 부모노드로 바꾼다. 이 과정을 통해서 현재 heavy chain에 대해서는 계산을 완료하고, 다음 heavy chain으로 넘겨주게 된다.
a와 b가 같은 heavy chain에 속해있다면 heavy chain은 세그먼트 트리 상에서 붙어있으므로 한 번에 쿼리를 통해서 처리를 할 수 있다. 

BOJ 13512의 전체 코드는 아래와 같다.

더보기
#include<bits/stdc++.h>
using namespace std;
typedef long long int ll;
typedef pair<ll,ll> pii;
#define ff first
#define ss second
#define eb emplace_back
#define pb push_back

ll n;
ll par[1010101];
ll vis[1010101];
ll dep[1010101];
ll sz[1010101];
ll in[1010101];
ll out[1010101];
ll top[1010101];
vector<ll> dv[1010101];
vector<ll> v[1010101];
ll rein[1010101];

void dfs1(ll x){
	sz[x]=1;
	for(auto n:v[x]){
		if(sz[n]) continue;
		dep[n]=dep[x]+1;
		par[n]=x;
		dfs1(n);
		sz[x]+=sz[n];
		dv[x].push_back(n);
		if(sz[n]>sz[dv[x][0]]) swap(dv[x][0],dv[x].back());
	}
}

ll np;
void dfs2(ll x){
	np++;
	in[x]=np;
	for(auto k:dv[x]){
		if(dv[x][0]==k) top[k]=top[x];
		else top[k]=k;
		dfs2(k);
	}
	out[x]=np;
}

struct node{
	ll val,left;
};

ll arr[1010101];
node tree[5050505];

node upt(ll idx,ll s,ll e,ll nod){
	if(idx<s||idx>e) return tree[nod];
	if(s==e){
		tree[nod].val=1LL-tree[nod].val;
		if(tree[nod].val) tree[nod].left=s;
		else tree[nod].left=1e18;
		return tree[nod];
	}
	node a=upt(idx,s,(s+e)/2,nod*2);
	node b=upt(idx,(s+e)/2+1,e,nod*2+1);
	tree[nod].val=a.val+b.val;
	tree[nod].left=min(a.left,b.left);
	return tree[nod];
}

ll sol(ll l,ll r,ll s,ll e,ll nod){
	if(r<s||e<l) return 1e18;
	if(l<=s&&e<=r) return tree[nod].left;
	return min(sol(l,r,s,(s+e)/2,nod*2),sol(l,r,(s+e)/2+1,e,nod*2+1));
}

ll psol(ll a,ll b){
	ll ans = 1e18;
	while(top[a]!=top[b]){
		if(dep[top[a]] > dep[top[b]]) swap(a, b);
		ll x=top[b];
		ans=min(ans,sol(in[x],in[b],1,n,1));
		b=par[x];
	}
	if(dep[a]>dep[b]) swap(a, b);
	ans=min(ans,sol(in[a],in[b],1,n,1)); 
	return ans;
}

int main(){
	cin.tie(0);
	cout.tie(0);
	ios_base::sync_with_stdio(0);
	ll i,j,k,l,m,a,b,c,d;
	cin>>n;

	for(i=1;i<n;i++){
		cin>>a>>b;
		v[a].push_back(b);
		v[b].push_back(a);
	}

	dfs1(1);
	top[1]=1;
	dfs2(1);

	for(i=1;i<=4*n;i++)
		tree[i].left=1e18;

	for(i=1;i<=n;i++){
		rein[in[i]]=i;
	}

	cin>>m;
	while(m--){
		cin>>a>>b;
		if(a==1){
			upt(in[b],1,n,1);
		}
		if(a==2){
			ll k=psol(1,b);
			if(k==1e18) cout<<-1<<'\n';
			else cout<<rein[k]<<'\n';
		}
	}
}

추천 문제

트리와 쿼리 1 www.acmicpc.net/problem/13510

트리와 쿼리 3 www.acmicpc.net/problem/13512

국제 메시 기구www.acmicpc.net/problem/17429