본문 바로가기
PS/BOJ

백준 14897 자바 - 서로 다른 수와 쿼리 1 (BOJ 14897 JAVA)

by Nahwasa 2021. 11. 19.

문제 : https://www.acmicpc.net/problem/14897

 

  확실히 나한테 아직 플래수준은 어려운것같다. 아예 모르는 알고리즘 지식을 요구하는 문제가 아닌이상 꾸역꾸역 풀긴 하는 편인데 사실 이런게 CP에서 나왔다고 치면 시간내에 절대 못푼다. 못해도 1시간 이상은 걸리는 것 같다. 오답률도 높은편 ㅠ 뭐 꾸준히 하다보면 언젠가 플래수준도 쉽게 풀겠지!

 

생각 과정

풀이만 보려면 아래로 쭉 내려서 '풀이 정리'로 바로 가시면 됩니다.

 

1. 우선 문제에서 바로바로 생각나는걸 정리했다.

- 업데이트가 없는 쿼리 이므로, 쿼리 순서를 마음대로 처리해도 된다. (오프라인 쿼리)

 

- N이 최대 100만, 배열에 포함된 수는 최대 10억이다. 이 때, 문제에서 묻는건 서로 다른 수의 개수이므로 배열에 포함된 수를 마음대로 변경해도 된다. N개가 모두 다른 수로 들어와도 압축해서 100만 이하로 만들 수 있다.

 

- 이 문제의 경우 당연히 매번 [L, R] 구간 내의 모든 수의 갯수를 확인해보면 시간초과다. 따라서 이전 쿼리에서 확인했던 값들 중 이번에 확인중인 쿼리와 겹치는 구간을 어떻게 효율적으로 재사용할지가 관건이다. 

 

- 처음에 떠오른 생각은 ORDER BY L ASC, R ASC 순서로 쿼리를 정렬하는 것이었다(수학 기호로 어떻게 적어야하는지 몰라서 쿼리문으로 대체함 ㅋㅋ).

 

  예를 들어 쿼리가 [1,3], [1,10], [2,8], [2,5] 이런식으로 들어온다면 정렬해서 [1,3], [1,10], [2,5], [2,8]과 같은 순서로 정렬한다. S가 1인 [1,3], [1,10]을 A라 하고, S가 2인 [2,5], [2,8]을 B라 하자. 우선 A에 대해서는 E가 변경된 구간만큼만 추가로 확인하면 된다. 그다음 A를 보다가 B로 넘어갈 경우, A의 마지막 E와 비교해서 재사용할 부분 재사용하는 식으로 생각했었다. 

 

 

2. '1'까지의 내용을 바탕으로 일단 무작정 코드를 짜봤다. 아래와 같다.

import java.io.DataInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;

class Query implements Comparable<Query>{
    int s, e, order;   // s means l, e means r.

    public Query(int s, int e, int order) {
        this.s = s;
        this.e = e;
        this.order = order;
    }

    @Override
    public int compareTo(Query another) {
        if (this.s == another.s) return this.e-another.e;
        return this.s-another.s;
    }
}

public class Main extends FastInput {
    private void solution() throws Exception {
        int n = nextInt();

        int[] arr = new int[n+1];
        HashMap<Integer, Integer> compression = new HashMap<>();
        int compNum = 0;
        for (int i = 1; i <= n; i++) {
            int a = nextInt();
            if (!compression.containsKey(a)) {
                compression.put(a, compNum++);
            }
            arr[i] = compression.get(a);
        }
        compression = null;

        int q = nextInt();
        ArrayList<Query> queries = new ArrayList<>(q);
        for (int i = 0; i < q; i++) {
            queries.add(new Query(nextInt(), nextInt(), i));
        }

        Collections.sort(queries);

        int answer = 0;
        int[] answerArr = new int[q];
        int[] cnt = new int[compNum];   // maximum 100,000

        // init first
        for (int i = queries.get(0).s; i <= queries.get(0).e; i++) {
            if (++cnt[arr[i]] == 1) answer++;
        }
        answerArr[queries.get(0).order] = answer;

        // solve
        for (int idx = 1; idx < queries.size(); idx++) {
            int cs = queries.get(idx).s;
            int ce = queries.get(idx).e;
            int bs = queries.get(idx-1).s;
            int be = queries.get(idx-1).e;

            for (int i = bs; i < cs; i++) if (--cnt[arr[i]] == 0) answer--;
            for (int i = cs; i < bs; i++) if (++cnt[arr[i]] == 1) answer++;
            for (int i = be+1; i <= ce; i++) if (++cnt[arr[i]] == 1) answer++;
            for (int i = ce+1; i <= be; i++) if (--cnt[arr[i]] == 0) answer--;

            answerArr[queries.get(idx).order] = answer;
        }

        StringBuilder output = new StringBuilder();
        for (int i = 0; i < q; i++) {
            output.append(answerArr[i]).append('\n');
        }
        System.out.print(output);
    }

    public static void main(String[] args) throws Exception {
        initFI();
        new Main().solution();
    }
}

class FastInput {
    private static final int DEFAULT_BUFFER_SIZE = 1 << 16;
    private static DataInputStream inputStream;
    private static byte[] buffer;
    private static int curIdx, maxIdx;

    protected static void initFI() {
        inputStream = new DataInputStream(System.in);
        buffer = new byte[DEFAULT_BUFFER_SIZE];
        curIdx = maxIdx = 0;
    }

    protected static int nextInt() throws IOException {
        int ret = 0;
        byte c = read();
        while (c <= ' ') c = read();
        do {
            ret = ret * 10 + c - '0';
        } while ((c = read()) >= '0' && c <= '9');
        return ret;
    }

    private static byte read() throws IOException {
        if (curIdx == maxIdx) {
            maxIdx = inputStream.read(buffer, curIdx = 0, DEFAULT_BUFFER_SIZE);
            if (maxIdx == -1) buffer[0] = -1;
        }
        return buffer[curIdx++];
    }
}

 

 

3. 위의 과정에서 시간초과가 났고, 세그먼트 트리쪽으로도 생각해봤는데 세그쪽은 별로 친하지 않기도하고 딱히 생각이 안나서 결국 서로 겹치는 구간을 최대한 많이하도록 순서대로 정렬할 방법을 찾아야 했다.

 

  여러 방식으로 해봤는데 아무래도 수학적 사고가 부족하다보니 전부 시간초과 났다 ㅠㅠ. 메모리 초과는 혹시라도 압축 시간이라도 줄여볼까 해서 압축 안하니 바로 메모리 초과났다ㅋㅋㅋ

 

 

4. 보통은 더이상 생각 안나면 그냥 덮어두고 나중에 공부하다보면 언젠가 풀겠지 하고 넘어가는 편인데, 이런 문제를 이미 꽤 넘어갔었다. 이번에 검색해보니 Mo's algorithm라는 녀석이 딱 이 경우에 맞는 알고리즘이었다. 이런식으로 좌우로 겹치면서 중복되는 부분을 최대한으로 유지해도록 하는(즉, 이전 L과 이번 L의 차이, 이전 R과 이번 R의 차이를 최소화) 알고리즘이었다. 아예 처음 들어본 알고리즘이었다 ㅋㅋ

 

  다만 '2'의 과정까지가 사실상 해당 알고리즘 동작 방식을 이미 사용한 상태였다(정신승리). 정렬 방식에 대한 아이디어만 참고해서 푸니 통과됬다. 모스 알고리즘이 정렬 방식을 ORDER BY L/sqrt(N), R로 한 후, 위에 적어뒀던 과정을 진행하는 것이었다. 수학 신기해.. 뭐 덕분에 처음 들어본 알고리즘인데 금방 익힐 수 있게 되서 완전 이득. 다만 알고리즘 찾아보고 푼거라 좀 진 기분이긴 하다. 주말에 모스 알고리즘 관련된 문제를 좀 더 풀어보면 앞으로 잘 사용할 것 같다.

 

 

 

풀이 정리

A. 들어온 A를 배열의 크기 N 이하로 압축한다. (코드 32~41 line)

 

B. 쿼리를 ORDER BY L/sqrt(N) ASC, R ASC 로 정렬한다(Mo's algorithm + offline query). 이 때 쿼리가 처음 들어왔던 순서는 알고 있어야 순서대로 출력할 수 있다.

 

C. 정렬된 쿼리에서, 이전 쿼리와 이번 쿼리의 L과 R의 변경을 확인해서 겹치는 구간을 재사용하며 답을 찾아준다. (코드 62~74 line)

 

D. 기존 쿼리의 순서대로 답을 출력한다.

 

 

코드 : https://github.com/NaHwaSa/BOJ_BaekjunOnlineJudge/blob/master/14800/BOJ_14897.java

import java.io.DataInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;

class Query implements Comparable<Query>{
    int s, e, order;   // s means l, e means r.
    static int sqrtN;

    public Query(int s, int e, int order) {
        this.s = s;
        this.e = e;
        this.order = order;
    }

    @Override
    public int compareTo(Query another) {
        int thisS = this.s / sqrtN;
        int anotherS = another.s / sqrtN;
        if (thisS == anotherS) return this.e-another.e;
        return thisS - anotherS;
    }
}

public class Main extends FastInput {
    private void solution() throws Exception {
        int n = nextInt();
        Query.sqrtN = (int) Math.sqrt(n);

        int[] arr = new int[n+1];
        HashMap<Integer, Integer> compression = new HashMap<>();
        int compNum = 0;
        for (int i = 1; i <= n; i++) {
            int a = nextInt();
            if (!compression.containsKey(a)) {
                compression.put(a, compNum++);
            }
            arr[i] = compression.get(a);
        }
        compression = null;

        int q = nextInt();
        ArrayList<Query> queries = new ArrayList<>(q);
        for (int i = 0; i < q; i++) {
            queries.add(new Query(nextInt(), nextInt(), i));
        }

        Collections.sort(queries);

        int answer = 0;
        int[] answerArr = new int[q];
        int[] cnt = new int[compNum];   // maximum 100,000

        // init first
        for (int i = queries.get(0).s; i <= queries.get(0).e; i++) {
            if (++cnt[arr[i]] == 1) answer++;
        }
        answerArr[queries.get(0).order] = answer;

        // solve
        for (int idx = 1; idx < queries.size(); idx++) {
            int cs = queries.get(idx).s;
            int ce = queries.get(idx).e;
            int bs = queries.get(idx-1).s;
            int be = queries.get(idx-1).e;

            for (int i = bs; i < cs; i++) if (--cnt[arr[i]] == 0) answer--;
            for (int i = cs; i < bs; i++) if (++cnt[arr[i]] == 1) answer++;
            for (int i = be+1; i <= ce; i++) if (++cnt[arr[i]] == 1) answer++;
            for (int i = ce+1; i <= be; i++) if (--cnt[arr[i]] == 0) answer--;

            answerArr[queries.get(idx).order] = answer;
        }

        StringBuilder output = new StringBuilder();
        for (int i = 0; i < q; i++) {
            output.append(answerArr[i]).append('\n');
        }
        System.out.print(output);
    }

    public static void main(String[] args) throws Exception {
        initFI();
        new Main().solution();
    }
}

class FastInput {
    private static final int DEFAULT_BUFFER_SIZE = 1 << 16;
    private static DataInputStream inputStream;
    private static byte[] buffer;
    private static int curIdx, maxIdx;

    protected static void initFI() {
        inputStream = new DataInputStream(System.in);
        buffer = new byte[DEFAULT_BUFFER_SIZE];
        curIdx = maxIdx = 0;
    }

    protected static int nextInt() throws IOException {
        int ret = 0;
        byte c = read();
        while (c <= ' ') c = read();
        do {
            ret = ret * 10 + c - '0';
        } while ((c = read()) >= '0' && c <= '9');
        return ret;
    }

    private static byte read() throws IOException {
        if (curIdx == maxIdx) {
            maxIdx = inputStream.read(buffer, curIdx = 0, DEFAULT_BUFFER_SIZE);
            if (maxIdx == -1) buffer[0] = -1;
        }
        return buffer[curIdx++];
    }
}

댓글