본문 바로가기
PS/BOJ

[자바] 백준 27501 - RGB트리 (java)

by Nahwasa 2024. 5. 23.

목차

    문제 : boj27501

     

     

    필요 알고리즘

    • 트리 DP (트리 + 동적계획법)
      • 트리에서 DP를 사용해 푸는 문제이다.

    ※ 제 코드에서 왜 main 함수에 로직을 직접 작성하지 않았는지, 왜 Scanner를 쓰지 않고 BufferedReader를 사용했는지 등에 대해서는 '자바로 백준 풀 때의 팁 및 주의점' 글을 참고해주세요. 백준을 자바로 풀어보려고 시작하시는 분이나, 백준에서 자바로 풀 때의 팁을 원하시는 분들도 보시는걸 추천드립니다.

     

     

    풀이

      '백준 1149 RGB거리' 문제의 트리버전이라고 보면 된다. 그러니 저걸 안풀었다면 저것부터 풀어보자. 1149를 못푸는데 이 문제를 풀 수 있을 가능성은 없다.

     

      우선 간소화해서 생각해보기 위해, 이 문제를 1149번 문제처럼 트리긴한데 그냥 1차원이었다고 생각해보자. 즉 간선이 그냥 '1 2, 2 3, 3 4, 4 5' 처럼 입력되었다고 생각해보자.

     

      이 경우라면 dp[x][c]를 x 구간을 c 색상으로 했을 때 가능한 아름다움의 최댓값이라 설정할 수 있을꺼다. 그리고 arr[x][c]는 x구간이 c색상일 때의 아름다움이라고 하자. 그럼 아래와 같이 해결할 수 있을꺼다.

    for (int x = 1; x <= N; x++) {
        dp[x][0] = arr[x][0] + max(dp[x-1][1], dp[x-1][2]);
        dp[x][1] = arr[x][1] + max(dp[x-1][0], dp[x-1][2]);
        dp[x][2] = arr[x][2] + max(dp[x-1][0], dp[x-1][1]);
    }
    
    print(max(dp[N][0], dp[N][1], dp[N][2]);

     

     

     

      근데 위는 1차원이니깐 배열로 처리해도 되는데, 트리로 확장을 해야하니 동일한 동작을 하긴 하지만 순차접근이 아니라 dfs로 접근하는걸 고려해보자. 즉, 위의 코드는 그냥 반복문으로 x를 1부터 N까지 증가시키면서 진행했지만, dfs를 써서 다음과 같이 역으로 타고내려가면서 구해도 동일한 동작이다.

    int ans = 0;
    for (int i = 0; i < 3; i++) {
    	ans = max(ans, dfs(N, i));
    }
    print(ans);
    
    int dfs(int x, int 현재색상) {
    	int res = 0;
    	if (현재색상 == 0)
        	res = max(res, arr[x][현재색상] + max(dfs(x-1, 1), dfs(x-1, 2));
        else if (현재색상 == 1)
        	res = max(res, arr[x][현재색상] + max(dfs(x-1, 0), dfs(x-1, 2)); 
        else
        	res = max(res, arr[x][현재색상] + max(dfs(x-1, 0), dfs(x-1, 1));
            
        return dp[x][현재색상] = res + arr[x][현재색상];
    }

     

     

      그럼 1차원에서 했던 dfs를 그냥 트리에 적용시키면 된다. 루트노드를 시작점으로 해서, 모든 리프노드까지 타고 내려가면서 결과를 역으로 구해나가는거다. 이 때, '트리' 이므로 루트 노드는 뭘로 잡던 상관 없다. 그냥 아무 정점으로부터 시작해서 위 dfs를 해나가면 된다.

    int ans = 0;
    for (int i = 0; i < 3; i++) {
    	ans = max(ans, dfs(N, i));
    }
    print(ans);
    
    int dfs(int x, int 현재색상) {
    	int sum = 0;
        
        for (x와 인접한 정점 z를 순회하면서) {
            int res = 0;
            if (현재색상 == 0)
                res = max(res, arr[x][현재색상] + max(dfs(z, 1), dfs(z, 2));
            else if (현재색상 == 1)
                res = max(res, arr[x][현재색상] + max(dfs(z, 0), dfs(z, 2)); 
            else
                res = max(res, arr[x][현재색상] + max(dfs(z, 0), dfs(z, 1));
           	
            sum += res;
        }
        return dp[x][현재색상] = sum + arr[x][현재색상];
    }

     

     

      즉, 1자로 되어있었으므로 그냥 이전 것만 보면 됬던거에서, 중간에 분기쳐지는 곳들이 생김으로 인해 각 가지의 결과를 더해주는 형태로 바뀌었을 뿐이다. 코드상에서 find() 함수가 이렇게 dfs하면서 트리를 dp하는 역할이다.

     

     

      위의 과정을 통해 최대 아름다움을 구했다면, 이제 해당 최대 수치가 가능하게 하는 전구의 색을 출력해줘야 한다. 우린 위의 과정을 통해 임의의 정점 a를 루트로 정해두고, dfs를 통해 최대 아름다움을 구했다. 이 때 최대 아름다움일 때의 a 정점의 색상은 쉽게 알 수 있을꺼다. dp[a][0]가 최대 아름다움인데, dp[a][1]인지, dp[a][2]인지만 비교해보면 된다.

     

      그럼 이후 인접한 정점도 마찬가진데, dp[a][1]이 최대 아름다움이었다고 한다면, a의 색상은 G였을꺼다. 그 다음 정점은 수식상 max(dp[인접한정점][0], dp[인접한정점][2]) 였을 테고, dp[인접한정점][0]이 더 컸다고 한다면 인접한 정점의 색상은 R이었을꺼다. 즉, 루트 노드로 정한 임의의 정점 a의 색상을 결정했다면, 그 인접한 정점의 색상을 정할 수 있고(애초에 수식을 통해 구한거니, 역으로 수식을 적용하면 된다), 인접한 정점의 인접한 정점의 색상도 정할 수 있는거다.

     

      그럼 중간에 동일한 수치가 있다면 어떨까? 예를들어 dp[a][0]은 50, dp[a][1]은 40, dp[a][2]는 50 이었다. 그럼 a는 R과 B 중 뭘로 해야할까? 상관없다. 추가 조건 없이 '가능한 답이 여러 가지라면 아무거나 출력한다' 라고 했으므로 어느쪽을 구하든 모든 정점의 색상을 정할 수 있다. 그냥 인접한 정점끼리 색상이 다르기만 하면 된다. 코드에서 '// tracking color' 이하의 부분이 이 동작을 하는 부분이다. find() 함수의 재귀를 쓴 dfs처럼, Stack을 쓴 tracking color 주석 부분도 dfs를 한거다. 그냥 함수 하나 더 쓰는김에 스택써서 dfs 하고 싶었다. 

     

      로직이 이해됬다고 해도 구현이 좀 까다로울 순 있다. 트리dp 관련해서 생각하기 어렵지 않으면서도, 역추적까지 해야 해서 추천할만한 문제라 생각된다.

     

     

    코드 : github

    import java.io.BufferedReader;
    import java.io.InputStreamReader;
    import java.util.*;
    
    import static java.lang.Math.*;
    
    public class Main {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in), 1 << 16);
        public static void main(String[] args) throws Exception {
            new Main().solution();
        }
    
        private static final char[] COLOR = new char[]{'R', 'G', 'B'};
        private int n;
        private List<Integer>[] edges;
        private int[][] arr;
        private int[][] dp;
        private boolean[] v;
    
        private void solution() throws Exception {
            // init
            n = Integer.parseInt(br.readLine());
            dp = new int[n+1][3];
            v = new boolean[n+1];
            edges = new List[n+1];
            for (int i = 0; i <= n; i++) edges[i] = new ArrayList<>();
            for (int i = 1; i < n; i++) {
                StringTokenizer st = new StringTokenizer(br.readLine());
                int a = Integer.parseInt(st.nextToken());
                int b = Integer.parseInt(st.nextToken());
                edges[a].add(b);
                edges[b].add(a);
            }
    
            arr = new int[n+1][3];
            for (int i = 1; i <= n; i++) {
                StringTokenizer st = new StringTokenizer(br.readLine());
                arr[i][0] = Integer.parseInt(st.nextToken());
                arr[i][1] = Integer.parseInt(st.nextToken());
                arr[i][2] = Integer.parseInt(st.nextToken());
            }
            
            
            // find maximum answer
            int ans = 0;
            edges[0].add(1);
            for (int i = 0; i < 3; i++) ans = max(ans, find(0, i));
            
            
            // tracking color
            char[] selected = new char[n+1];
            Stack<int[]> stk = new Stack<>();
            for (int i = 0; i < 3; i++) {
                if (dp[0][i] == ans) {
                    stk.push(new int[]{0, i});
                    break;
                }
            }
    
            while (!stk.isEmpty()) {
                int[] cur = stk.pop();
                int idx = cur[0];
                int color = cur[1];
                selected[idx] = COLOR[color];
    
                for (int next : edges[idx]) {
                    if (v[next]) continue;
                    v[next] = true;
    
                    int nextColor = -1;
                    int nextMax = 0;
                    for (int i = 0; i < 3; i++) {
                        if (color == i) continue;
    
                        if (nextMax < dp[next][i]) {
                            nextMax = dp[next][i];
                            nextColor = i;
                        }
                    }
    
                    stk.push(new int[]{next, nextColor});
                }
            }
    
            
            // print answer
            StringBuilder sb = new StringBuilder();
            sb.append(ans).append('\n');
            for (int i = 1; i <= n; i++) sb.append(selected[i]);
            System.out.println(sb);
        }
    
        private int find(final int idx, final int color) {
            if (dp[idx][color] != 0)
                return dp[idx][color];
    
            int sum = 0;
    
            for (int next : edges[idx]) {
                if (v[next]) continue;
                v[next] = true;
    
                int max = 0;
                for (int i = 0; i < 3; i++) {
                    if (color == i) continue;
    
                    max = max(max, find(next, i));
                }
                sum += max;
    
                v[next] = false;
            }
    
            return dp[idx][color] = sum + arr[idx][color];
        }
    }
    
    class Stack<T> {
        private final Deque<T> dq = new ArrayDeque<>();
        boolean isEmpty(){return dq.isEmpty();}
        T pop() {return dq.pollFirst();}
        void push(T e) {dq.addFirst(e);}
    }

     

    댓글