본문 바로가기
PS/BOJ

[자바] 백준 28251 - 나도리합 (java)

by Nahwasa 2023. 8. 11.

목차

    문제 : boj28251

     

     

    필요 알고리즘

    • 수학, 분리 집합 (disjoint set, union-find)
      • 기본적인 아이디어는 수학적 직관이 좀 필요하고, 그 직관을 구현하기 위해 분리 집합 알고리즘을 사용해야 하는 문제이다.

    ※ 제 코드에서 왜 main 함수에 로직을 직접 작성하지 않았는지, 왜 Scanner를 쓰지 않고 BufferedReader를 사용했는지 등에 대해서는 '자바로 백준 풀 때의 팁 및 주의점' 글을 참고해주세요. 백준을 자바로 풀어보려고 시작하시는 분이나, 백준에서 자바로 풀 때의 팁을 원하시는 분들도 보시는걸 추천드립니다.

     

     

    풀이

      4개의 수 a,b,c,d가 있다고 해보자.

    a와 b를 합치면 ab, c와 d를 합치면 cd가 필요하다.

     

      그 후 ab와 cd를 합쳐보면 아래처럼 식이 나올꺼다.

    ab + ac+ad+bc+bd + cd

    여기서 빨간 글자 부분만 식을 전개해보면 아래처럼 된다.

     

    ac+ad+bc+bd = a(c+d) + b(c+d) = (a+b)(c+d)

     

      그럼 최종적으로 ab + (a+b)(c+d) + cd 형태가 된다. 즉, 예를들어 x,y,z를 합친 전투력을 f(xyz) 처럼 표현하고, x,y,z 각각의 전투력을 합친걸 g(xyz) = x+y+z 라고 해보자.

     

    이 경우 xyz 그룹과 abc 그룹을 합친다면

    f(xyzabc) = f(xyz) + g(xyz)*g(abc) + f(abc) 처럼 구할 수 있다.

    이걸 어떻게 찾았냐면, 그냥 공책에 ab와 cd를 합쳐보다보니 저렇게 나오길래 다른 경우도 해보니 맞아서 ㅗㅜㅑ를 외치며 구현했다. 내 경우 수학이 매우 약해서.. 얻어걸린 셈이다.

     

      코드에서 sum[]은 g()를 뜻하고, power[]는 f() 를 뜻한다. 근데 배열에서 위에서 얘기한 xyz나 abc를 어떻게 표현할 수 있을까? 그리고 문제에서는 그룹끼리 합치는게 아니라 나도리의 번호가 주어진다! 이걸 보고 어느 그룹에 속해있는지 어떻게 표현가능할까? 이걸 해결해주는게 분리 집합 알고리즘(union-find)이다.

     

      이하는 문제에서 주어진 a,b 를 입력받아 두 나도리 그룹을 합치고, 합쳐진 나도리 그룹의 전투력을 리턴하는 함수에 설명을 달아둔 것이다. 기본적으로는 union-find 알고리즘을 알고 있어야 이해할 수 있으므로 모른다면 그걸 먼저 익혀보자.

    long solve(int a, int b) {
        a = find(a);	// a를 그룹번호로 변경
        b = find(b);	// b를 그룹번호로 변경
        if (a == b) return power[a];	// 동일한 그룹이라면 해당 그룹의 전투력 리턴
    
    	// 이하 4줄은 그냥 분리집합 알고리즘을 효율적으로 돌리기 위한 부분이다.
        // 잘 모르겠으면 hi를 a, lo를 b라고 봐도 된다.
        int hi = parents[a] < parents[b] ? a:b;
        int lo = parents[a] < parents[b] ? b:a;
        parents[hi] += parents[lo];
        parents[lo] = hi;
    
    	// 글에서 설명한대로 얘기해보면 power[hi]는 f(xyzabc)인 셈이고,
        // += 이므로 이미 f(xyz)에다가
        // power[lo]인 f(abc)를 더하고
        // sum[hi]*sum[lo]로 g(xyz)*g(abc)를 더한다. 즉, 글에서 설명한것과 동일하다.
        power[hi] += power[lo] + sum[hi] * sum[lo];
        
        // g(xyzabc) = g(xyz) + g(abc)
        sum[hi] += sum[lo];
    
    	// return f(xyzabc)
        return power[hi];
    }

     

     

    코드 : github

    import java.io.BufferedReader;
    import java.io.InputStreamReader;
    import java.util.Arrays;
    import java.util.StringTokenizer;
    
    public class Main {
        static BufferedReader br = new BufferedReader(new InputStreamReader(System.in), 1<<16);
    
        public static void main(String[] args) throws Exception {
            new Main().solution();
        }
    
        int[] parents;
        long[] sum, power;
    
        int find(int a) {
            if (parents[a] < 0) return a;
            return parents[a] = find(parents[a]);
        }
    
        long solve(int a, int b) {
            a = find(a);
            b = find(b);
            if (a == b) return power[a];
    
            int hi = parents[a] < parents[b] ? a:b;
            int lo = parents[a] < parents[b] ? b:a;
            parents[hi] += parents[lo];
            parents[lo] = hi;
    
            power[hi] += power[lo] + sum[hi] * sum[lo];
            sum[hi] += sum[lo];
    
            return power[hi];
        }
    
        public void solution() throws Exception {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int n = Integer.parseInt(st.nextToken());
            int q = Integer.parseInt(st.nextToken());
    
            parents = new int[n];
            Arrays.fill(parents, -1);
            sum = new long[n];
            power = new long[n];
    
            st = new StringTokenizer(br.readLine());
            for (int i = 0; i < n; i++) sum[i] = Integer.parseInt(st.nextToken());
    
            StringBuilder sb = new StringBuilder();
            while (q-->0) {
                st = new StringTokenizer(br.readLine());
                int a = Integer.parseInt(st.nextToken()) - 1;
                int b = Integer.parseInt(st.nextToken()) - 1;
                sb.append(solve(a, b)).append('\n');
            }
    
            System.out.print(sb);
        }
    }

     

    댓글