본문 바로가기
PS/BOJ

[자바] 백준 2820 - 자동차 공장 (java)

by Nahwasa 2022. 8. 15.

 문제 : boj2820


 

필요 알고리즘 개념

  •  lazy propagation을 적용한 세그먼트 트리 혹은 펜윅 트리
    • 세그먼트 트리 + lazy propagation 혹은 range update가 가능한 펜윅 트리를 알고 있어야 풀 수 있다.
  • 오일러 경로 테크닉
    • 내 경우엔 이런게 있는줄 모르고 풀고보니 내가 한게 이거였다. 그러니 필요한 알고리즘 개념은 아니다.

※ 제 코드에서 왜 main 함수에 로직을 직접 작성하지 않았는지, 왜 Scanner를 쓰지 않고 BufferedReader를 사용했는지 등에 대해서는 '자바로 백준 풀 때의 팁 및 주의점' 글을 참고해주세요. 백준을 자바로 풀어보려고 시작하시는 분이나, 백준에서 자바로 풀 때의 팁을 원하시는 분들도 보시는걸 추천드립니다.

 


 

풀이

1. a의 모든 부하의 월급을 x만큼 증가시킨다 = range update

2. a의 월급을 출력한다 = point query

 

  범위 업뎃과 단일 값 획득 문제이다. 일반적으로 세그먼트 lazy propagation을 사용해 풀게 되는데, 구간을 잘 나눌 수 있다면 작성해둔 펜윅 트리 글에서 '응용 2'를 가지고 풀 수 있다.

 

  1번 상근이의 경우 상사가 없으므로, 1번을 루트로 해서 그래프를 그려보자. 예제 입력 2를 그려보겠다. 2,3,4번은 1의 부하이고, 5는 4의 부하이다.

  원하는 것은 저 그래프를 배열 형태로 바꿔서(위상 정렬 하듯이 1줄로 쭉 나오게), 'a의 모든 부하의 월급을 x만큼 증가한다'를 '[A, B] 구간의 값을 x만큼 증가한다'. 처럼 변경해보려 한다. 즉, 입력으로 들어온 그래프 형태를 1차원 배열 형태로 납작하게 변경하는 것이다. 왜냐하면 그래프 그대로 두면 매번 탐색하면서 증가시켜야 하니 O(N)이 필요하고. 1차원에 대해 [A, B] 구간의 update로 변경한다면 세그먼트 트리의 lazy propagation이나, 펜윅 트리의 range update 응용으로 O(logN)에 풀 수 있기 때문이다.

 

  방법은 간단한데, 1번 노드부터 시작해서 dfs를 돌려주면서 새로 만나는 노드 순서대로 번호를 붙여주면 된다. 그게 [A, B]에서 A에 해당한다. 그리고 재귀적으로 구현했을 때, 자신을 통해 시작된 재귀함수가 생성하는 최대 A값이 B에 해당하게 된다. 물론 해당 노드에서 진행 가능한 간선이 여러개라면 어딜 먼저 가더라도 상관없다. 과정은 다음과 같이 된다.

1
2
3
4
5

 

  위의 과정을 거쳐서 각 정점의 [A, B]가 정해진다. 이러한 방식을 '오일러 경로 테크닉'이라고 한다. 그럼 이제 a의 모든 부하의 월급을 x만큼 증가시키라고 할 경우, [A, B] 구간에 대해 x만큼 증가시키는 range update가 된다. 예를들어 4의 모든 부하의 월급을 5만큼 증가시키라고 한다면 정점 4의 [A, B]는 [3, 4] 이므로 [3, 4] 구간을 5만큼 증가시켜주면 된다. 주의점은 정점 번호와 A, B는 서로 다른 값이다. A, B는 임의로 만들어진 번호이므로 매핑해줘야 하는것을 헷갈리면 안된다. 만약 직원 2의 월급을 출력하라고 한다면 2의 A는 5 이므로 만들어둔 세그먼트나 펜윅트리 자료구조에서는 5의 값을 봐야 한다.

 


 

코드 : github

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;

public class Main {
    int n, num = 1;
    int[] bit, mapping, rangeEnd, salary;
    ArrayList<Integer>[] sub;

    private int query(int idx) {
        idx = mapping[idx];
        int sum = 0;
        while (idx > 0) {
            sum += bit[idx];
            idx -= idx&-idx;
        }
        return sum;
    }

    private void update(int idx, int diff) {
        while (idx <= n) {
            bit[idx] += diff;
            idx += idx&-idx;
        }
    }

    private void rangeUpdate(int s, int e, int diff) {
        update(s, diff);
        update(e+1, -diff);
    }

    private void rangeUpdate(int idx, int diff) {
        int s = mapping[idx];
        int e = rangeEnd[s];
        if (e == 0)
            return;
        rangeUpdate(s+1, e, diff);
    }

    private int initNumAndRange(int idx) {
        int end = (mapping[idx] = num++);
        if (sub[idx] == null)
            return end;

        for (int next : sub[idx]) {
            end = initNumAndRange(next);
        }
        rangeEnd[mapping[idx]] = end;
        return end;
    }

    private void initBIT() {
        for (int i = 1; i <= n; i++) {
            rangeUpdate(mapping[i], mapping[i], salary[i]);
        }
    }

    private void init(int n) {
        this.n = n;
        sub = new ArrayList[n+1];
        salary = new int[n+1];
        mapping = new int[n+1];
        rangeEnd = new int[n+1];
        bit = new int[n+1];
    }

    private void solution() throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        init(Integer.parseInt(st.nextToken()));
        int m = Integer.parseInt(st.nextToken());

        salary[1] = Integer.parseInt(br.readLine());
        for (int i = 2; i <= n; i++) {
            st = new StringTokenizer(br.readLine());
            salary[i] = Integer.parseInt(st.nextToken());
            int cur = Integer.parseInt(st.nextToken());
            if (sub[cur] == null)
                sub[cur] = new ArrayList<>();
            sub[cur].add(i);
        }

        initNumAndRange(1);
        initBIT();

        StringBuilder sb = new StringBuilder();
        while (m-->0) {
            st = new StringTokenizer(br.readLine());
            switch (st.nextToken().charAt(0)) {
                case 'p':
                    rangeUpdate(Integer.parseInt(st.nextToken()), Integer.parseInt(st.nextToken()));
                    break;
                case 'u':
                    sb.append(query(Integer.parseInt(st.nextToken()))).append('\n');
            }
        }
        System.out.print(sb);
    }

    public static void main(String[] args) throws Exception {
        new Main().solution();
    }
}

댓글