코딩/백준 문제 풀이

BOJ 15647 - 로스팅하는 엠마도 바리스타입니다

stonejjun 2020. 3. 5. 04:19

문제 태그 : https://www.acmicpc.net/problem/15647

문제 소개

간선의 가중치가 있는 트리가 주어진다. 1번 노드 부터 N번 노드까지 각 노드에 대해서 다른 모든 노드에서 그 노드까지의 거리의 합을 출력하시오.

문제 풀이

처음에는 하나당 O(logN)에 구하는 방법을 생각했다. HLD나 centriod decomposition을 떠올렸지만, 잘 모르는 쪽이기 때문에 더 이상 풀이를 떠올릴 수 없었다. 분명히 centriod를 활용해 풀 수 있을 것 같긴하다. 그래서 전체를 O(N)으로 구하는 방식으로 바꾸었다. Tree DP를 이용하는 방식이다.

F[i] = 내 서브트리의 노드들이 i번 노드까지 오는데 걸리는 거리의 합
cnt[i] = 내 버스트리의 노드 수
dp[i] = 내 서브트리를 제외한 노드들이 i번 노드까지 오는데 걸리는 거리의 합
dep[i] = 루트노드에서 i번 노드까지의 거리

맨 처음에 DFS로 Tree를 확실하게 만들어준 다음, 아래에서 부터 올라오면서 cnt[i]와 F[i]를 구해준다.
p를 부모노드, s를 자식노드라고 할 때 부모노드와 자식노드간의 dp값의 관계를 살펴보면
기존의 dp[p] 에 전체 - cnt[s] 갯수의 노드가 (dep[s]-dep[p]) 만큼 추가 이동하고, p의 s를 제외한 나머지 서브트리가 s로 이동하는 값을 더해야 한다.
이를 계산하면 dp[s]=F[p]-F[s]+dp[p]+(n-2*cnt[s])*(dep[s]-dep[p]) 이라는 식이 나온다.
각 노드에 대한 답은 dp[i]+f[i]가 될 것이다.

코드 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include<bits/stdc++.h>
using namespace std;
typedef long long int ll;
typedef pair<ll,ll> pii;
#define ff first 
#define ss second
#define ep emplace_back
ll dep[1010101];
ll par[1010101];
ll cnt[1010101];
ll f[1010101];
ll ans[1010101];
ll dp[1010101];
 
vector<pii> v[1001011];
ll n;
 
void dfs(ll x){
    for(auto nn:v[x]){
        if(dep[nn.ff]) continue;
        dep[nn.ff]=dep[x]+nn.ss;
        par[nn.ff]=x;
        dfs(nn.ff);
    }
}
void dfs2(ll x){
    ans[x]=f[x]+dp[x];
    for(auto nn:v[x]){
        ll s=nn.ff;
        if(ans[s]) continue;
        dp[s]=f[x]-f[s]+dp[x]+(n-2*cnt[s])*(dep[s]-dep[x]);
        dfs2(s);
    }
}
 
 
ll arr[1010101];
 
bool sf(ll a,ll b){
    return dep[a]>dep[b];
}
 
int main(){
    ll i,j,k,l,m;
    scanf("%lld",&n);
    for(i=1;i<n;i++){
        ll a,b,c;
        scanf("%lld %lld %lld",&a,&b,&c);
        v[a].ep(b,c);
        v[b].ep(a,c);
    }
    dep[1]=1;
    dfs(1);
 
    for(i=1;i<=n;i++)
        arr[i]=i;
 
    sort(arr+1,arr+1+n,sf);
 
    for(i=1;i<=n;i++){
        k=arr[i];
        for(auto x:v[k]){
            if(x.ff==par[k]) continue;
            f[k]+=f[x.ff]+cnt[x.ff]*(dep[x.ff]-dep[k]);
            cnt[k]+=cnt[x.ff];
        }
        cnt[k]++;
    }
    dfs2(1);
    for(i=1;i<=n;i++)
        printf("%lld\n",ans[i]);
}
 
 

 

'코딩 > 백준 문제 풀이' 카테고리의 다른 글

BOJ 10903 - Wall construction  (0) 2020.03.11
BOJ 2162 - 선분 그룹  (2) 2020.03.08
BOJ 18292 - NM과 K (2)  (0) 2020.02.29
BOJ 17955 - Max or Min  (0) 2020.02.17
BOJ 1006 - 습격자 초라기  (0) 2020.02.09