본문 바로가기
PS/BOJ

백준 14413 자바 - Poklon (BOJ 14413 JAVA)

by Nahwasa 2021. 12. 11.

문제 : boj14413

 

 

1.

  update가 없는 쿼리 문제로, 오프라인 쿼리로 처리해도 된다. 또한 각 쿼리마다 범위가 계속 변경되는 와중에 원하는 답을 구하는 쿼리 문제이다. 따라서 mo's algorithm을 사용할 경우 효율적으로 풀 수 있다. 모스 알고리즘은 쿼리의 수가 n이고 각 쿼리 qi = (a, b)라 할 때(문제에선 l, r이었음. l이 헷갈려서 a, b로 변경하여 설명함), 모든 쿼리를 (a/sqrt(n), b) 순서로 정렬하게 된다(a/sqrt(n)은 소수점 버리고 정수로). 이 경우 각 쿼리를 정렬 후 수행 시 모든 경우에 대해 a와 b의 변화량이 최대 N*sqrt(N) 정도로 고정되고, 효율적으로 직전에 구해둔 쿼리의 데이터 중 중복되는 부분을 재사용할 수 있다. 예를들어 q1=(1,5), q2=(2,4)일 경우 q1에서 1~5까지 구해둔 결과를 가지고 q2를 돌릴 땐 2~4를 모두 보는게 아니라, q1의 결과에서 1과 5를 제외하면 된다.

 

 

2.

  이 문제에서 구하려는건 qi=(a,b)에서 a~b 구간 사이에 정확히 두 번만 나타난 서로다른 숫자의 개수이다. 이 때 매번 모든 쿼리의 a~b를 살펴보면 비효율적이므로 모스 알고리즘을 사용하려 하는데, 그러려면 나타났던 숫자들에 대한 카운팅이 필요하다. 예를들어 다음과 같은 입력을 살펴보자.

5 3
1 2 1 2 1
1 3
1 5
3 5

'1'을 적용해서 정렬해도 순서는 동일하게 (1, 3) -> (1, 5) -> (3, 5) 이다.

 

2.1 시작은 다음과 같다.

 

2.2 기존에 데이터가 없고, 처음 볼 쿼리는 a=1, b=3 이므로 우선 첫번째 수(idx 1의 num)를 보겠다. 위의 배열에서 해당 수를 보고, 아래 배열에 해당 수를 카운팅 하는 것이다.

 

2.3 다음으로 idx2를 본다.

 

2.4 마지막으로 idx3을 확인한다. 이와 같이 구간을 더해나가는 중에 해당수의 cnt가 2가 됬다면 답으로 판단할 수 있다. 따라서 (1,3)쿼리에 대한 답은 1개가 된다.

 

2.5 다음으로 두번째 쿼리인 (1, 5)를 보자. (1, 3) 쿼리에서 idx4, idx5만 추가로 더해넣어주면 된다. idx4를 보면 2가 추가되어 cnt가 2가 됬으므로 현재까지 답은 2개가 된다.

 

2.6 마지막으로 idx5를 보자. 그럼 num1에 대한 cnt가 3이 되면서 답이 될 수 없다. 따라서 (1, 5) 쿼리에 대한 답은 1개가 된다. 이와 같이 구간을 더해나가는 중에 cnt가 3이 되면 답에서 빠진다.

 

2.7 이번엔 마지막 쿼리인 (3, 5)를 보자. 이전 쿼리인 (1, 5)에서 idx1과 idx2를 빼내면 된다. 우선 idx1을 보자. 이번엔 빼내는 것이므로 cnt가 감소한다. 마찬가지로, 구간을 빼나가는 중에 cnt가 2가 되면 답이 될 수 있다.

 

2.8 idx2를 보자. 이번에도 마찬가지로 구간을 빼나가는 중에 cnt가 1이 되면 답에서 제외된다.

 

2.9 '1'의 과정을 통해 효과적으로 쿼리의 순서를 변경하고 + '2'의 과정을 통해 답을 구해나가면 효율적으로 이 문제를 풀 수 있다.

 

 

3.

  그런데 '2'의 과정을 진행하려면 문제가 하나 있다. 각 숫자가 최대 10억까지 가능하다는 점인데, '2'를 설명할 때 그렸던 카운팅을 위한 배열에 num이 10억까지 있어야 한다는 것이다. 그리고 이 문제의 메모리 제한은 512mb 이므로 메모리초과로 그렇게 큰 배열은 불가능하다.

 

  이 문제에서 N은 최대 500,000 이고, 실제 값을 출력하는 것이 아니라 해당 구간에서 2개만 있는 수의 개수만 출력하면 되는 문제이므로 실제 값과 상관없이 모든 값은 50만 이하로 표현 가능하다(N이 50만이고 모든 수가 달랐더라도, 최대 50만개의 다른 수 이므로). 따라서 값 압축을 통해 10억의 수를 50만 이하로 압축시키면 메모리초과 문제도 해결할 수 있다.

 

 

코드 : github

import java.io.DataInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;

class Query implements Comparable<Query> {
    int a, b, idx;
    static int sqrtN;

    public Query(int a, int b, int idx) {
        this.a = a;
        this.b = b;
        this.idx = idx;
    }

    @Override
    public int compareTo(Query o) {
        int sqrtA1 = this.a / sqrtN;
        int sqrtA2 = o.a / sqrtN;
        if (sqrtA1 == sqrtA2) return this.b - o.b;
        return sqrtA1 - sqrtA2;
    }
}

public class Main extends FastInput {
    private int answerCnt = 0;
    private int[] cnt;

    private void pushNum(int num) {
        switch (++cnt[num]) {
            case 2 : answerCnt++; break;
            case 3 : answerCnt--; break;
        }
    }

    private void popNum(int num) {
        switch (--cnt[num]) {
            case 1 : answerCnt--; break;
            case 2 : answerCnt++; break;
        }
    }

    private int[] findQueryResult(int[] arr, int q, ArrayList<Query> queries) {
        int[] queryResult = new int[q];

        Query first = queries.get(0);
        for (int i = first.a; i <= first.b; i++) {
            pushNum(arr[i]);
        }
        queryResult[first.idx] = answerCnt;

        for (int queryIdx = 1; queryIdx < q; queryIdx++) {
            Query cur = queries.get(queryIdx);
            Query bf = queries.get(queryIdx-1);

            for (int i = bf.a; i < cur.a; i++) popNum(arr[i]);
            for (int i = cur.a; i < bf.a; i++) pushNum(arr[i]);
            for (int i = bf.b+1; i <= cur.b; i++) pushNum(arr[i]);
            for (int i = cur.b+1; i <= bf.b; i++) popNum(arr[i]);

            queryResult[cur.idx] = answerCnt;
        }
        return queryResult;
    }

    private void solution() throws Exception {
        int n = nextInt();
        Query.sqrtN = (int) Math.sqrt(n);
        int q = nextInt();

        HashMap<Integer, Integer> numCompression = new HashMap<>();
        int num = 1;
        int[] arr = new int[n+1];
        for (int i = 1; i <= n; i++) {
            int cur = nextInt();
            if (!numCompression.containsKey(cur))
                numCompression.put(cur, num++);
            arr[i] = numCompression.get(cur);
        }
        cnt = new int[num];

        ArrayList<Query> queries = new ArrayList<>();
        for (int i = 0; i < q; i++) {
            queries.add(new Query(nextInt(), nextInt(), i));
        }
        Collections.sort(queries);

        int[] queryResult = findQueryResult(arr, q, queries);
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < q; i++) {
            sb.append(queryResult[i]).append('\n');
        }
        System.out.print(sb);
    }

    public static void main(String[] args) throws Exception {
        initFI();
        new Main().solution();
    }
}

class FastInput {
    private static final int DEFAULT_BUFFER_SIZE = 1 << 16;
    private static DataInputStream inputStream;
    private static byte[] buffer;
    private static int curIdx, maxIdx;

    protected static void initFI() {
        inputStream = new DataInputStream(System.in);
        buffer = new byte[DEFAULT_BUFFER_SIZE];
        curIdx = maxIdx = 0;
    }

    protected static int nextInt() throws IOException {
        int ret = 0;
        byte c = read();
        while (c <= ' ') c = read();
        do {
            ret = ret * 10 + c - '0';
        } while ((c = read()) >= '0' && c <= '9');
        return ret;
    }

    private static byte read() throws IOException {
        if (curIdx == maxIdx) {
            maxIdx = inputStream.read(buffer, curIdx = 0, DEFAULT_BUFFER_SIZE);
            if (maxIdx == -1) buffer[0] = -1;
        }
        return buffer[curIdx++];
    }
}

댓글