코딩/알고리즘 & 자료구조

Segment Tree 응용 (3) - 구간 내 부분 합 최대 (금광 세그)

stonejjun 2020. 6. 7. 23:28

이번에 포스팅할 Segment Tree의 활용법은 구간 내에서 최대 부분합을 찾는 방법이다. 흔히 "금광 세그"라고 불리는 방식이며, 금광이라는 문제에 이 기법이 사용되어졌다. 

어떠한 문제에서 두가지 쿼리가 주어졌다고 하자.
1번 쿼리로 수열에서 a번째 숫자를 b로 바꾸는 쿼리가 들어온다.
2번 쿼리로 l,r이 주어졌을때 수열의 [l,r]내에서 연속된 부분합의 최댓값을 구하는 것이다. 
이러한 문제를 Segment Tree를 활용해서 해결할 것이다. 

어떤 노드에 그냥 최대 부분합만 담고 있다고 생각해보자. 아래 두 개의 노드를 합칠때 그냥 둘 중 최댓값을 고르게 되면 왼쪽 절반 안에 속해있는 부분합과 오른쪽 절반안에 속해있는 부분합 밖에 알 수 없다. 두 구간을 걸친 부분들의 합은 고려해 줄 수 없는 것이다. 따라서 아래와 같은 방식을 채택한다.

세그먼트 트리의 어떤 노드에는 4개의 수를 담는다.
Lval = 노드가 담당하는 구간의 왼쪽 값을 포함 하는 최대 부분합.
Rval = 노드가 담당하는 구간의 오른쪽 값을 포함 하는 최대 부분합.
val = 노드가 담당하는 구간의 최대 부분합
all = 노드가 담당하는 구간의 전체 합

어떤 노드 X의 두 자식 노드를 L과 R이라고 하자. 그러면 두 자식노드를 합치는 과정은 다음과 같이 진행된다.

$X_{Lval}=max(L_{Lval},L_{all}+R_{Lval})$
$X_{Rval}=max(R_{Rval},R_{all}+L_{Rval})$
$X_{val}=max(L_{val},R_{val},L_{Rval}+R_{Lval})$
$X_{all}=L_{all}+R_{all}$

이런식으로 자식의 노드로부터 부모의 노드의 모든 값을 완벽히 정의 할 수 있다. 따라서 세그먼트 트리로써 구간에 대한 쿼리를 처리할 수 있다. 

물론 그냥 이렇게 4개의 값을 저장하면 된다. 관계는 이런식으로 나온다. 라고 말은 했지만, 그림을 그리거나 직접 계산을 해서 한 번 몸으로 체감해 보고 좀 더 생각을 해보는 것도 좋다. 곱씹어보면 볼수록 굉장히 신박하고 좋은 방법이라고 생각한다. 

코드

아래 코드는 BOJ 15561 구간합 최대 ? 2 의 정답 코드이다. 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include<bits/stdc++.h>
using namespace std;
const int inf=1e9;
int n,q,u,v;
int arr[1010100];
 
struct nod{
    int allg,lg,rg,midg;
};
nod tree[4040400];
 
nod update(int idx,int val,int node, int l, int r)
{
    if(idx<l||idx>r)
    {
        return tree[node];
    }
    if(l==r)
    {
        tree[node]={val,val,val,val};
        return tree[node];
    }
    nod a = update(idx,val,node*2,l,(l+r)/2);
    nod b = update(idx,val,node*2+1,((l+r)/2)+1,r);
    tree[node].lg=max(a.lg,a.allg+b.lg);
    tree[node].rg=max(b.rg,a.rg+b.allg);
    tree[node].allg=a.allg+b.allg;
    tree[node].midg=max(max(a.midg,b.midg),a.rg+b.lg);
    return tree[node];
}
 
nod ans(int node, int l, int r, int st, int en) {
    if (r < st || en < l)
    {
        nod a={0,-inf,-inf,-inf};
        return a;
    }
    if (st <= l && r <= en)  return tree[node];
    
    nod a = ans(node*2,l,(l+r)/2,st,en);
    nod b = ans(node*2+1,(l+r)/2+1,r,st,en);
    nod c;
    c.lg=max(a.lg,a.allg+b.lg);
    c.rg=max(b.rg,a.rg+b.allg);
    c.allg=a.allg+b.allg;
    c.midg=max(max(a.midg,b.midg),a.rg+b.lg);
    return c;
}
 
int main()
{
    scanf("%d %d %d %d"&n, &q, &u, &v);
    for(int i=1;i<=4*n;i++)
    {
        tree[i].allg=0;
        tree[i].lg=-inf;
        tree[i].midg=-inf;
        tree[i].rg=-inf;
    }
    for (int i=1;i<=n;i++) {
        scanf("%d"&arr[i]);
        update(i, u * arr[i] + v, 11, n);
    }
 
    while (q--)
    {
        int c, a, b;
        scanf("%d %d %d"&c, &a, &b);
        if (c==0)
            printf("%d\n", ans(11, n, a, b).midg - v);
        else
        {
            arr[a] = b;
            update(a, u * b + v, 11, n);
        }
    }
 
}
 
cs