본문 바로가기
PS/BOJ

[자바] 백준 13544 - 수열과 쿼리 3 (java)

by Nahwasa 2022. 9. 17.

 문제 : boj13544


 

필요 알고리즘 개념

  • 정해 : 머지 소트 트리 + 세그먼트 트리
    • 정해는 세그먼트 트리를 응용한 머지 소트 트리로 보인다.
  • 내 경우 : 펜윅 트리
    • 머지 소트 트리를 펜윅 트리로 적용시켜서도 풀 수 있다.

※ 제 코드에서 왜 main 함수에 로직을 직접 작성하지 않았는지, 왜 Scanner를 쓰지 않고 BufferedReader를 사용했는지 등에 대해서는 '자바로 백준 풀 때의 팁 및 주의점' 글을 참고해주세요. 백준을 자바로 풀어보려고 시작하시는 분이나, 백준에서 자바로 풀 때의 팁을 원하시는 분들도 보시는걸 추천드립니다.

 


 

풀이

  일반적으로는 대부분 세그먼트 트리에 각 노드를 배열로 처리한 머지 소트 트리를 사용해 풀었다. 내 경우엔 아직 세그먼트 트리를 정식으로 공부하지 않았으므로, 내가 알고 있는 선에서 처리해보려고 펜윅 트리로 응용해서 적용해봤다. 따라서 정해가 아니고, 내 방식대로 푼 것이므로 머지 소트 트리로 풀려면 다른 풀이를 보기 바란다.

 

  우선 이 문제는 매번 이전 결과값을 xor해서 쿼리를 생성하므로, 오프라인 쿼리로 처리가 불가하다. 만약 오프라인 쿼리라면 단순히 mo's algorithm에 세그먼트나 sqrt decomposition을 적용하면 풀 수 있다. 하지만 온라인 쿼리로 처리해야 하므로 알고리즘이 어느정도 강제되게 된다.

 

  머지 소트 트리는 세그먼트 트리의 각 노드를 정렬된 배열로 처리한다. 그렇다면, 펜윅 트리의 각 노드를 정렬된 배열로 처리하더라도 적어도 이 문제는 문제없이 풀 수 있다고 생각했다(머지 소트 트리를 공부해보면 알 수 있겠지만, 머지 소트 트리로 처리 가능한걸 항상 펜윅 트리에 정렬된 배열을 추가한다고 풀 수 있는건 아니다). 다만 펜윅 트리는 [1, x] 구간의 k보다 큰 원소의 개수를 알 수 있는게 문제인데, [a, b] 구간의 k보다 큰 원소의 개수는 결국 [1, b] 구간의 k보다 큰 원소의 수에서 [1, a-1] 구간의 k보다 큰 원소의 수를 빼는 것으로 구할 수 있다.

 

  펜윅 트리에 대해 잘 모른다면 이 문제는 point update, range query 이므로 '적어둔 펜윅 트리 글'에서 '기본' 부분을 읽어보자. 다만 펜윅 트리의 값이 위 글에서는 int나 long 처럼 단일값이었다면, 이 문제에서는 배열 형태면 된다.

 

1. point update를 해당 펜윅트리의 각 노드에 넣어준다.

private void update(int idx, int val) {
    while (idx <= n) {
        bit[idx].add(val);
        idx += idx&-idx;
    }
}

 

2. 중간중간 query가 진행되는 방식이 아니고, update가 모두 진행된 후 query가 진행되므로 update가 끝난 후 펜윅 트리의 각 노드를 정렬해줘도 문제 없다.

for (int i = 1; i <= n; i++) {
    Collections.sort(bit[i]);
}

 

3. 이후 k보다 큰 값의 개수는 정렬된 펜윅 트리의 각 노드에서 이분 탐색을 통해 알아낼 수 있다.

private int getCntOfOver(ArrayList<Integer> list, int k) {
    int res = Collections.binarySearch(list, k+1);
    if (res < 0) {
        res += 1;
        res = -res;
    }
    return list.size() - res;
}

 

이게 가능한 이유는, 기본 펜윅트리에서 query 시 구간이 겹치는 부분이 없기 때문이다. 각 구간에서 k보다 큰 수의 합으,ㄹ 알아낸다면 전체 구간의 k보다 큰 수와 동일하다.

 

 


 

코드 : github

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.StringTokenizer;

public class Main {
    ArrayList<Integer>[] bit;
    int n;

    private void update(int idx, int val) {
        while (idx <= n) {
            bit[idx].add(val);
            idx += idx&-idx;
        }
    }
    
    private int getCntOfOver(ArrayList<Integer> list, int k) {
        int res = Collections.binarySearch(list, k+1);
        if (res < 0) {
            res += 1;
            res = -res;
        }
        return list.size() - res;
    }

    private int prefixSumOfCnt(int idx, int k) {
        int cnt = 0;
        while (idx > 0) {
            cnt += getCntOfOver(bit[idx], k);
            idx -= idx&-idx;
        }
        return cnt;
    }

    private int query(int i, int j, int k) {
        return prefixSumOfCnt(j, k) - prefixSumOfCnt(i-1, k);
    }

    private void solution() throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        n = Integer.parseInt(br.readLine());
        bit = new ArrayList[n+1];
        for (int i = 1; i <= n; i++) bit[i] = new ArrayList<>();
        StringTokenizer st = new StringTokenizer(br.readLine());
        for (int i = 1; i <= n; i++) {
            update(i, Integer.parseInt(st.nextToken()));
        }
        for (int i = 1; i <= n; i++) {
            Collections.sort(bit[i]);
        }

        int m = Integer.parseInt(br.readLine());
        int lastAnswer = 0;
        StringBuilder sb = new StringBuilder();
        while (m-->0) {
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            int c = Integer.parseInt(st.nextToken());
            int i = a^lastAnswer;
            int j = b^lastAnswer;
            int k = c^lastAnswer;
            lastAnswer = (int)query(i, j, k);
            sb.append(lastAnswer).append('\n');
        }
        System.out.print(sb);
    }
    public static void main(String[] args) throws Exception {
        new Main().solution();
    }
}

댓글