본문 바로가기
PS/BOJ

[자바] 백준 17353(펜윅트리) - 하늘에서 떨어지는 1, 2, ..., R-L+1개의 별 (java)

by Nahwasa 2022. 8. 23.

 문제 : boj17353


 

필요 알고리즘 개념

  • 세그먼트 트리 lazy propagation 또는 펜윅 트리 개념
    • 세그먼트 트리를 통한 lazy propagation 혹은 range update 펜윅 트리를 알고 있어야 풀 수 있다. 또는 기본 펜윅트리로도 풀 수 있다.

※ 제 코드에서 왜 main 함수에 로직을 직접 작성하지 않았는지, 왜 Scanner를 쓰지 않고 BufferedReader를 사용했는지 등에 대해서는 '자바로 백준 풀 때의 팁 및 주의점' 글을 참고해주세요. 백준을 자바로 풀어보려고 시작하시는 분이나, 백준에서 자바로 풀 때의 팁을 원하시는 분들도 보시는걸 추천드립니다.

 


 

풀이 1 - 펜윅 트리 range update + range query를 통한 풀이

  우선 내가 처음 생각한 풀이부터 작성한다. 펜윅 트리에 대한 개념이 필요하고, 이 풀이는 펜윅트리의 range update - range query 응용 방식을 사용했다. 이하 풀이를 이해하려면 펜윅 트리 글에서 '응용3'을 참고하자.

 

 

  기본적으로 펜윅트리로 lazy propagation을 구현하려면 range update가 동일한 수치여야 한다. 이 문제의 경우 range update가 등차수열 형태이므로 다음과 같이 동일한 수치로 변경이 필요하다. 원본 배열을 A라고 하고, B를 다음과 같이 정의하자.

B[X] = A[X] - A[X-1]

 

 

  그럼 우선 A[X]를 얻는 방법부터 생각해보면, B[1]+B[2]+...+B[X] = A[X]이다. 예를들어 B[1] + B[2] + B[3] = A[1]-0+A[2]-A[1]+A[3]-A[2] = A[3] 이다. 따라서 문제에서는 point query 이지만 B에 대해서는 range query로 변경된다.

 

 

  [L, R]에 대한 update시 B의 변화는 다음과 같이 생각해볼 수 있다. (B[X]' 은 업데이트 이후, B[X]는 업데이트 이전이다)

1. B[L]' = A[L]+1-A[L-1] = B[L]+1 이므로 1이 증가한다.

2. B[L+1] = A[L+1]+2-(A[L]+1) = B[L+1]+1 이므로 마찬가지로 1이 증가한다.

3. 이후 B[R-1]까진 '2'와 동일하게 1이 증가한다.

4. B[R]' = A[R]+R-L+1-(A[R-1]+R-L) = B[R]+1 로 B[R]'도 마찬가지로 1이 증가한다.

5. B[R+1]' = A[R+1]-(A[R]+R-L+1) = B[R+1]-R+L-1 이므로 여기선 '-R+L-1'이 증가한다.

 

 

  따라서 위에서 링크를 걸어둔 글에서 응용3을 B 배열에 대해 적용해보면 다음과 같다.

1. [L,R]에 대한 range update는 [L,R] 구간에 대해 B배열을 +1씩 해주고 [R+1, R+1] 구간에 대해 '-R+L-1'을 해주는 것과 동일하다. -> range update

2. A[X]는 B에 대한 prefix sum을 구하면 알 수 있다. -> range query

 

  위의 내용을 작성한 펜윅트리 글의 응용3로 구현해주면 된다. 

  

 

코드 (풀이 1) : github

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {
    long[] bit1, bit2;
    int n;

    private void update(int bitType, int idx, long diff) {
        long[] bit = bitType==1 ? bit1 : bit2;
        while (idx <= n) {
            bit[idx] += diff;
            idx += idx&-idx;
        }
    }

    private void rangeUpdate(int a, int b, long diff) {
        update(1, a, diff);
        update(1, b+1, -diff);
        update(2, a, diff * (a-1));
        update(2, b+1, -diff * b);
    }

    private long getBitValue(int bitType, int idx) {
        long[] bit = bitType==1 ? bit1 : bit2;
        long answer = 0;
        while (idx > 0) {
            answer += bit[idx];
            idx -= idx&-idx;
        }
        return answer;
    }

    private long prefixSum(int idx) {
        return getBitValue(1, idx) * idx - getBitValue(2, idx);
    }

    private void solution() throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        n = Integer.parseInt(br.readLine());
        bit1 = new long[n+1];
        bit2 = new long[n+1];
        StringTokenizer st = new StringTokenizer(br.readLine());
        for (int i = 1; i <= n; i++) {
            int cur = Integer.parseInt(st.nextToken());
            rangeUpdate(i, i, cur);
            rangeUpdate(i+1, i+1, -cur);
        }

        int q = Integer.parseInt(br.readLine());
        StringBuilder sb = new StringBuilder();
        while (q-->0) {
            st = new StringTokenizer(br.readLine());
            switch (Integer.parseInt(st.nextToken())) {
                case 1:
                    int l = Integer.parseInt(st.nextToken());
                    int r = Integer.parseInt(st.nextToken());
                    rangeUpdate(l, r, 1);
                    rangeUpdate(r+1, r+1, -r+l-1);
                    break;
                case 2:
                    int x = Integer.parseInt(st.nextToken());
                    sb.append(prefixSum(x)).append('\n');
                    break;
            }
        }
        System.out.print(sb);
    }

    public static void main(String[] args) throws Exception {
        new Main().solution();
    }
}

 

 


 

 

풀이 2 - 펜윅 트리 range update + point query를 통한 풀이

  풀이 1로 푼 이후, 다른 분들은 어떻게 풀었나 보다가 'iwantmoney'님의 멋진 코드를 해석하고 약간 개선해서(펜윅트리 4개 -> 2개 사용 이긴한데, 사실 개념은 완전히 동일하다.) 내가 작성했던 펜윅트리 글에서 '응용2'로 가능함을 확인하고 작성해본다. 이하 풀이를 이해하려면 펜윅 트리 글에서 '응용2' 부분을 참고하자.

 

  [L,R]에 대한 range update에 대해 [L,R] 사이의 X를 생각해보자(L<=X<=R). 가능한 모든 X 인덱스에 대해 A[X]+=X-L+1 이 이뤄질 것이다. 그럼 저기서 X+1과 -L을 따로 생각해보면, A[X]에 대해 X+1은 당연히 언제든 알 수 있는 값이다. 따라서 X에 대해 연산이 일어난 횟수만 가지고 있어도 된다. 반면에 -L은 X와 관계가 없으므로 고정값이다.

 

  그럼 A[X]에 연산이 일어난 횟수 cnt와 -L의 합계 sum만 알고 있다면, 이후 A[X]는 cnt*(X+1)+sum 으로 구할 수 있다. 따라서 cnt용 펜윅트리와, sum용 펜윅트리를 따로 둔다. 그럼 [L,R]에 대한 update는 cnt[L]+=1, cnt[R+1]-=1, sum[L]-=L, sum[R]+=L 이 된다. 이후 point query로 A[X]에 해당하는 cnt값과 sum값을 획득해서 cnt*(X+1)+sum 을 출력해주면 된다.

 

 

코드 (풀이 2) : github

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {
    int n;
    long[] cntBit, sumBit;

    private void update(long[] bit, int i, long diff) {
        while (i <= n) {
            bit[i] += diff;
            i += i&-i;
        }
    }

    private long query(long[] bit, int i) {
        long sum = 0;
        while (i > 0) {
            sum += bit[i];
            i -= i&-i;
        }
        return sum;
    }

    private void rangeUpdate(long[] bit, int i, int j, long diff) {
        update(bit, i, diff);
        update(bit, j+1, -diff);
    }

    private void solution() throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        n = Integer.parseInt(br.readLine());
        long[] arr = new long[n+1];
        cntBit = new long[n+1];
        sumBit = new long[n+1];
        StringTokenizer st = new StringTokenizer(br.readLine());
        for (int i = 1; i <= n; i++) {
            arr[i] = Integer.parseInt(st.nextToken());
        }

        int q = Integer.parseInt(br.readLine());
        StringBuilder sb = new StringBuilder();
        while (q-->0) {
            st = new StringTokenizer(br.readLine());
            switch (Integer.parseInt(st.nextToken())) {
                case 1:
                    int l = Integer.parseInt(st.nextToken());
                    int r = Integer.parseInt(st.nextToken());
                    rangeUpdate(cntBit, l, r, 1);
                    rangeUpdate(sumBit, l, r, -l);
                    break;
                case 2:
                    int x = Integer.parseInt(st.nextToken());
                    sb.append(arr[x] + (x+1)*query(cntBit, x) + query(sumBit, x)).append('\n');
                    break;
            }
        }
        System.out.print(sb);
    }

    public static void main(String[] args) throws Exception {
        new Main().solution();
    }
}

댓글