카테고리 없음

BOJ 17974 - Same Color

stonejjun 2020. 3. 13. 16:23

797이 들어간 문제로 ICPC 2019 Seoul G번인 이 문제를 선택했다.

문제 태그 : https://www.acmicpc.net/problem/17974

문제 소개

점 N개의 좌표와 색깔이 주어진다. 우리는 이 점들중 1개이상, 몇 개의 점들을 뽑아야 한다. 뽑힌 점들을 A그룹이라고 하자. 뽑히지 않은 점들을 B그룹이라고 하자. 이때 모든 B 그룹의 점들은 가장 가까운 A그룹의 점들에 색깔이 같은 점이 존재해야 한다. 이때 뽑을 수 있는 A그룹의 최소 크기를 구하시오. 

문제 풀이 

일단 수직선상에 점을 흩뿌려놓고 생각을 해보자. 가장 먼저 연속된 같은 색깔의 점들을 그룹으로 묶어야 한다. 이때 관찰을 할 수 있다. 한 그룹내에서 최소 1개이상의 점은 뽑아야한다. 한 그룹내에서 최대 2개의 점까지만 뽑으면 된다. 

일단 처음에 greedy 하게 생각을 했다. 하지만 완벽하게 '가장 좋은 상황' 이라는 것이 정의가 되지 않는다. 한 그룹내에서 가장 뒤의 점을 선택해야 좋을 때도 있고, 중간쯤의 점을 선택해야 좋을 때도 있었다. 따라서 DP로 생각을 바꾸었다. 

2차원이 성립하지 않기 때문에 DP[i]= 1부터 i까지만 있을때의 문제에 대한 답으로 설정했다. 그후 dp값들의 연관성을 얻어내기 위해 연속된 그룹의 특징을 관찰하였다. 어떤 그룹에서 마지막으로 선택되는 점이 정해지면 다음 그룹에서 처음으로 선택될 수 있는 점들이 정해진다. 이는 연속된 점들의 구간으로 정해질 것이다. 
따라서 DP[i]= i를 뽑으면서 1~i까지만 있을때의 문제에 대한 답으로 설정을 바꾸었다.

DP[i]에 의해서 다음 그룹과 비교해 가면서 가능한 구간은 min(기존,dp[i]+1)로 업데이트 해주면 된다. 이때 한 그룹내에서 최대 2개의 점만 뽑으면 해결이 되므로, 만약 가능한 구간이 존재 했다면 그 이후부터 다음 그룹의 끝까지는 min(기존, dp[i]+2)로 업데이트 해주면 된다. 구간에 min을 취하는 연산은 segment tree beats라는 것을 이용해 처리해 줄 수 있다.

코드

사실 구하는 값이 범위의 합, 범위 max등이 아니기 때문에 segment tree 만 써도 충분히 해결가능한 문제였다. 하지만 segment tree beats를 써봤었기 때문에, 그냥 사용하였다. 덕분에 쓸모 없는 부분의 코드가 많다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include<bits/stdc++.h>
using namespace std;
typedef long long int ll;
typedef pair<ll,ll> pii;
#define ff first
#define ss second
#define eb emplace_back
#define pb push_back
 
struct node{
    ll mx,mx2,cmx,val;
};
 
struct poi{
    ll dis,col,pgp;
};
 
bool sf(poi a,poi b){
    return a.dis<b.dis;
}
 
poi arr[1010101];
ll brr[1010101];
node tree[5050505];
 
node mg(node a,node b){
    if(a.mx==b.mx) return{a.mx,max(a.mx2,b.mx2),a.cmx+b.cmx,a.val+b.val};
    if(a.mx>b.mx) swap(a,b);
    return {b.mx,max(a.mx,b.mx2),b.cmx,a.val+b.val};
}
 
node init(ll s,ll e,ll nod){
    if(s==e) return tree[nod]={brr[s],-1,1,brr[s]};
    return tree[nod]=mg(init(s,(s+e)/2,nod*2),init((s+e)/2+1,e,nod*2+1));
}
 
void laz(ll s,ll e,ll nod){
    if(s==e) return;
    for(auto i:{nod*2,nod*2+1}){
        if(tree[nod].mx<tree[i].mx){
            tree[i].val-=tree[i].cmx*(tree[i].mx-tree[nod].mx);
            tree[i].mx=tree[nod].mx;
        }
    }
}
 
void upt(ll l,ll r,ll v,ll s,ll e,ll nod){
    laz(s,e,nod);
    if(r<s||e<l||tree[nod].mx<=v) return ;
    if(l<=s&&e<=r&&tree[nod].mx2<v){
        tree[nod].val-=tree[nod].cmx*(tree[nod].mx-v);
        tree[nod].mx=v;
        laz(s,e,nod);
        return ;
    } 
    upt(l,r,v,s,(s+e)/2,nod*2);
    upt(l,r,v,(s+e)/2+1,e,nod*2+1);
    tree[nod]=mg(tree[nod*2],tree[nod*2+1]);
}
 
ll sol2(ll l,ll r,ll s,ll e,ll nod){
    laz(s,e,nod);
    if(r<s||e<l) return 0;
    if(l<=s&&e<=r) return tree[nod].val;
    return sol2(l,r,s,(s+e)/2,nod*2)+sol2(l,r,(s+e)/2+1,e,nod*2+1);
}
 
vector<ll> gp[1010101];
ll gs[1010101];
ll ge[1010101];
ll dp[1010101];
ll pgs[1010101];
 
int main(){
    ll i,j,k,l,m,n,a,b,c,q;
    scanf("%lld %lld",&m,&n);
    for(i=1;i<=n;i++)
        scanf("%lld",&arr[i].dis);
    for(i=1;i<=n;i++)
        scanf("%lld",&arr[i].col);
    sort(arr+1,arr+1+n,sf);
    for(i=1;i<=n;i++){
        gs[i]=1e18;
        ge[i]=-1e18;
        brr[i]=1e18;
    }
 
    ll gpcnt=0;
    for(i=1;i<=n;i++){
        if(arr[i].col!=arr[i-1].col) gpcnt++;
        gp[gpcnt].pb(arr[i].dis);
        gs[gpcnt]=min(gs[gpcnt],arr[i].dis);
        ge[gpcnt]=max(ge[gpcnt],arr[i].dis);
        pgs[gpcnt]=max(pgs[gpcnt],i);
        if(gpcnt==1) brr[i]=1;
        arr[i].pgp=gpcnt;
    }
    init(1,n,1);
    ll ans=1e18;
    for(i=1;i<=n;i++){
        dp[i]=sol2(i,i,1,n,1);
        ll ngp=arr[i].pgp;
        if(ngp==gpcnt){
            ans=min(ans,dp[i]);
            continue;
        }
        ll idx1=lower_bound(gp[ngp+1].begin(), gp[ngp+1].end(),ge[ngp]-arr[i].dis+ge[ngp])-gp[ngp+1].begin()+1;
        ll idx2=upper_bound(gp[ngp+1].begin(), gp[ngp+1].end(),gs[ngp+1]-arr[i].dis+gs[ngp+1])-gp[ngp+1].begin();
 
        idx1+=pgs[ngp];
        idx2+=pgs[ngp];
        if(idx2<idx1) continue;
        upt(idx1,idx2,dp[i]+1,1,n,1);
        upt(idx2+1,pgs[ngp+1],dp[i]+2,1,n,1);
    }
    printf("%lld",ans);
}