본문 바로가기
PS/BOJ

[자바] 백준 14268 - 회사 문화 2 (java)

by Nahwasa 2022. 11. 4.

 문제 : boj14268


 

필요 알고리즘 개념

  • lazy propagation을 적용한 세그먼트 트리 혹은 펜윅 트리
    • 세그먼트 트리 + lazy propagation 혹은 range update가 가능한 펜윅 트리를 알고 있어야 풀 수 있다.
  • 오일러 경로 테크닉, DFS
    • 알면 생각하기 좋은데, 바로 생각 못할만한 개념은 아니다(내 경우에도 처음 오일러 경로 테크닉을 접했을 때 풀고보니 저런 알고리즘이었다.). 어떠한 노드에서 그 아래로 내려가는 경로들을 1차원으로 펴서 연속된 구간을 생성하는 개념이다. 구할 때 DFS를 사용한다.

※ 제 코드에서 왜 main 함수에 로직을 직접 작성하지 않았는지, 왜 Scanner를 쓰지 않고 BufferedReader를 사용했는지 등에 대해서는 '자바로 백준 풀 때의 팁 및 주의점' 글을 참고해주세요. 백준을 자바로 풀어보려고 시작하시는 분이나, 백준에서 자바로 풀 때의 팁을 원하시는 분들도 보시는걸 추천드립니다.

 


 

풀이

  i번째 직원이 칭찬을 들을 경우, 부하로 연쇄적으로 칭찬을 하게 된다. 즉, '1 i w' 쿼리는 point update로 보이지만 실제론 여러명에 대한 update에 해당한다. 다만 그대로 사용하려면 연속된 구간이 아니라 띄엄띄엄 떨어져 있는 여러 개의 노드에 대한 update이다. 그렇다면 만약 연속된 구간에 대한 update로 변경할 수 있다고 한다면 어떨까?

 

  그렇게 가능하다면 range update(어떠한 연속된 구간에 칭찬을 한다) + point query(특정 직원이 칭찬을 받은 정도를 획득한다.)가 된다. 그럼 일반적으로 세그먼트 트리의 lazy propagation 응용을 사용하면 해결할 수 있다. 내 경우엔 펜윅트리를 사용해 해결했다. 펜윅트리를 모르거나, 펜윅 트리로 range update + point query를 처리하는 방법을 모른다면 이 글을 참고해보자(응용 2가 이 풀이에 적용되는 풀이이다.). lazy propagation 혹은 위에서 얘기한 펜윅 트리를 사용해 특정 구간에 대해 +@를 해주고, point query로 답을 내는건 해당 알고리즘을 사용한 기본 동작이므로 별도로 풀이하지 않겠다.

 

  그럼 이제 '1 i w' 워 같은 쿼리를 범위 쿼리로 변경하는 방법을 알아보자. 이걸 오일러 경로 테크닉이라고 하는데, 다음과 같이 상사 -> 부하로 간선을 연결한 그래프를 생각해보자. 입력의 두번째 줄이  '-1 1 1 1 4' 와 같이 들어왔다면 다음과 같은 그래프가 된다. 노드의 번호는 직원 번호를 뜻한다.

 

  하고자 하는건, dfs 방문 순서대로 새로운 번호를 매기면서 각 노드에 칭찬이 들어왔을 경우 새로운 번호를 기준으로 연속된 범위를 지정하려는 것이다. 지정된 범위가 [A, B]라고 하자. A는 당연히 dfs로 방문해 해당 노드에 새로 부여된 번호에 해당될 것이다. B는 해당 노드가 범위를 끼치는 부하들에 해당하게 된다. 그럼 다음 그림처럼 진행하면 된다. 코드의 initNumAndRange()를 참고해보자.

 

 

 

 

 

  모든 노드에 연속된 번호로 구성된 [A, B] 범위가 생겼다. 따라서 '1 i w' 형태의 point update는 이제 연속된 구간에 대한 range update가 되고, '2 i' 또한 i에 해당하는 변경된 수(A)만 알 수 있다면 point query로 구할 수 있다.

 


 

코드 : github

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;

public class Main {
    int n, num = 1;
    int[] bit, mapping, rangeEnd;
    ArrayList<Integer>[] sub;

    private int query(int idx) {
        idx = mapping[idx];
        int sum = 0;
        while (idx > 0) {
            sum += bit[idx];
            idx -= idx&-idx;
        }
        return sum;
    }

    private void update(int idx, int diff) {
        while (idx <= n) {
            bit[idx] += diff;
            idx += idx&-idx;
        }
    }

    private void rangeUpdate(int s, int e, int diff) {
        update(s, diff);
        update(e+1, -diff);
    }

    private void rangeUpdate(int idx, int diff) {
        int s = mapping[idx];
        int e = rangeEnd[s];
        if (e == 0)
            return;
        rangeUpdate(s, e, diff);
    }

    private int initNumAndRange(int idx) {
        int end = (mapping[idx] = num++);
        if (sub[idx] == null)
            return rangeEnd[end] = end;

        for (int next : sub[idx]) {
            end = initNumAndRange(next);
        }
        rangeEnd[mapping[idx]] = end;
        return end;
    }

    private void init(int n) {
        this.n = n;
        sub = new ArrayList[n+1];
        mapping = new int[n+1];
        rangeEnd = new int[n+1];
        bit = new int[n+1];
    }

    private void solution() throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        init(Integer.parseInt(st.nextToken()));
        int m = Integer.parseInt(st.nextToken());

        st = new StringTokenizer(br.readLine());
        st.nextToken();
        for (int i = 2; i <= n; i++) {
            int cur = Integer.parseInt(st.nextToken());
            if (sub[cur] == null)
                sub[cur] = new ArrayList<>();
            sub[cur].add(i);
        }

        initNumAndRange(1);

        StringBuilder sb = new StringBuilder();
        while (m-->0) {
            st = new StringTokenizer(br.readLine());
            switch (st.nextToken().charAt(0)) {
                case '1':
                    rangeUpdate(Integer.parseInt(st.nextToken()), Integer.parseInt(st.nextToken()));
                    break;
                case '2':
                    sb.append(query(Integer.parseInt(st.nextToken()))).append('\n');
            }
        }
        System.out.print(sb);
    }

    public static void main(String[] args) throws Exception {
        new Main().solution();
    }
}

댓글0