코딩/알고리즘 & 자료구조

Mo's algorithm (모스 알고리즘)

stonejjun 2020. 5. 25. 19:48

sqrt, 루트계의 스페셜리스트 알고리즘인 Mo's algorithm을 소개하려고 한다. 어떤 문제를 해결할 때 기본적으로 루트풀이는 친숙하지 않아서 떠올리기 힘들다. 하지만 그 와중에 Mo's algorithm은 굉장히 신박한 아이디어로 만든 시간복잡도이기 때문에 이 알고리즘은 정말 '알고 있어야' 사용이 가능하고 문제를 풀 수 있을 것이다.

Mo's algorithm 이란?

Mo's algorithm은 업데이트가 없이 오프라인으로 구간 쿼리가 많이 주어질 때, 그 구간쿼리들을 효율적으로 처리하는 알고리즘이다. 기본적으로 오프라인 쿼리여야 되고, 처리하는 쿼리들의 순서를 잘 조정함으로서 더 효율적으로 처리를 할 수 있게 됩니다.

Mo's algorithm의 작동 방식

1. 어떤 쿼리 (구간) 에 대한 답을 구한다.
2. 1번의 구간에서 앞부분의 구간을 더하거나 뺀다. 뒷부분의 구간을 더하거나 뺀다.
3. 2번 과정을 통해 다음 구할 쿼리의 구간에 도달한다.
4. 그 과정에서 다음 쿼리의 값을 계산한다.

예를 들어서 단순하게 구간합을 구하고 싶다고 하자.
배열은 (1,2,3,4,5,6,7,8)인 크기 8의 배열이고 1~8의 구간합, 2~6의 구간 합, 3~7의 구간합을 차례대로 구하고 싶다고 가정하자.

맨 처음 구간 합은 그냥 $O(N)$에 구할 수 있다.  그 전 구간의 크기를 0이라고 생각하고 구간 전체를 하나식 더해가능 과정에서 구할 수 있다.

구간 (1~8)에서 구간 (2~6)으로 오기 위해서는 앞에서 1을 제거하고, 뒷부분에서 7과 8을 제거하면 된다. 그냥 앞에서부터, 뒤에서 부터 순서대로 제거하면서 숫자를 빼주면 36-1-7-8이 되면서 20이 나오게 되고 20은 현재 원하고자 하는 구간의 합, 즉 2에서 6까지의 구간 합이다. 

위의 과정에서 구한 구간 (2~6)에서 구간 (3~7)로 넘어가려면 앞 쪽의 2를 빼고 뒷 쪽에 7을 붙이면 된다. 빼고 붙이는 과정에서 -2+7이 일어나기 때문에 3~7의 구간합은 25로 구해지게 된다.

시간 복잡도 분석

한 칸씩 움직일 때마다 $O(1)$이 걸리는 작업이라면 최악의 경우는 인접한 두 구역간에서 약 $O(N)$번을 움직여야 되고 한 개의 쿼리를 처리하는데 $O(N)$이 걸리게 된다. 따라서 구간들을 최대한 잘 조정해서 구간간의 변화에 따른 이동을 최소화 해야한다. 
이를 쿼리를 정렬함으로써 해결 할 수 있고, 그 쿼리의 구간 (l,r)에 대해서  정렬방법은 다음과 같다.
1. $l/ \sqrt{N}$에 대해서 먼저 정렬한다.
2. 1의 값이 같은 구간의 쿼리에 대해서 r에 대해서 정렬한다.

위와 같이 정렬을 한다고 하자. 일단 1번 정렬에 의해서 쿼리들이 $ \sqrt{N}$개의 그룹으로 나눠진다. 각 그룹에 대해서 r은 계속 증가하기 때문에 최대 $O(N)$만큼 움직인다. 또한 각 그룹에 대해서 l의 최대 차이는 $ \sqrt{N}$이기 때문에 모든 쿼리에 대해서 다음 쿼리까지 l의 이동은 최대 $ \sqrt{N}$회 이루어진다. 따라서 구간간의 이동에서 l과 r 모두 $O( \sqrt{N})$ 번의 이동을 하게 되고 따라서 총 시간복잡도 $O(N \sqrt{N})$만에 문제를 해결할 수 있다.

코드

아래의 코드는 BOJ 13547 수열과 쿼리 5를 푸는 코드이다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include<bits/stdc++.h>
using namespace std;
typedef long long int ll;
 
struct q{
    ll idx,s,e;
};
q qer[1010101];
ll ans[1010101];
ll cnt[1010101];
ll arr[1010101];
ll sqn;
ll now;
ll lf,rf;
 
bool sf(q a,q b)
{
    if(a.s/sqn!=b.s/sqn) return a.s<b.s;
    return a.e<b.e;
}
 
void mi(ll s, ll e)
{
    ll i;
    for(i=s;i<=e;i++)
    {
        cnt[arr[i]]--;
        if(cnt[arr[i]]==0) now--;
    }
}
 
void pl(ll s, ll e)
{
    ll i;
    for(i=s;i<=e;i++)
    {
        if(cnt[arr[i]]==0) now++;
        cnt[arr[i]]++;
    }
}
 
int main()
{
    ll i,j,k,l,m,n,Q;
    scanf("%lld",&n);
    sqn=sqrt(n);
    for(i=1;i<=n;i++)
        scanf("%lld",&arr[i]);
 
    scanf("%lld",&Q);
    for(i=1;i<=Q;i++)
    {
        scanf("%lld %lld",&qer[i].s,&qer[i].e);
        qer[i].idx=i;
    }
 
    sort(qer+1,qer+1+Q,sf);
 
 
    for(i=qer[1].s;i<=qer[1].e;i++)
    {
        if(cnt[arr[i]]==0) now++;
        cnt[arr[i]]++;
    }
    ans[qer[1].idx]=now;
    ll lf=qer[1].s;
    ll rf=qer[1].e;
    //for(j=1;j<=5;j++)
    //        printf("%lld ",cnt[j]);
    //    printf("\n");
 
    for(i=2;i<=Q;i++)
    {    
        if(qer[i].s<lf) pl(qer[i].s,lf-1);
        if(qer[i].e>rf) pl(rf+1,qer[i].e);
        if(qer[i].s>lf) mi(lf,qer[i].s-1);
        if(qer[i].e<rf) mi(qer[i].e+1,rf);
 
        lf=qer[i].s;
        rf=qer[i].e;
        ans[qer[i].idx]=now;
    }
 
    for(i=1;i<=Q;i++)
        printf("%lld\n",ans[i]);
}