본문 바로가기
PS/BOJ

[자바] 백준 28251 - 나도리합 (java)

by Nahwasa 2023. 8. 11.

문제 : boj28251

 

 

필요 알고리즘

  • 수학, 분리 집합 (disjoint set, union-find)
    • 기본적인 아이디어는 수학적 직관이 좀 필요하고, 그 직관을 구현하기 위해 분리 집합 알고리즘을 사용해야 하는 문제이다.

※ 제 코드에서 왜 main 함수에 로직을 직접 작성하지 않았는지, 왜 Scanner를 쓰지 않고 BufferedReader를 사용했는지 등에 대해서는 '자바로 백준 풀 때의 팁 및 주의점' 글을 참고해주세요. 백준을 자바로 풀어보려고 시작하시는 분이나, 백준에서 자바로 풀 때의 팁을 원하시는 분들도 보시는걸 추천드립니다.

 

 

풀이

  4개의 수 a,b,c,d가 있다고 해보자.

a와 b를 합치면 ab, c와 d를 합치면 cd가 필요하다.

 

  그 후 ab와 cd를 합쳐보면 아래처럼 식이 나올꺼다.

ab + ac+ad+bc+bd + cd

여기서 빨간 글자 부분만 식을 전개해보면 아래처럼 된다.

 

ac+ad+bc+bd = a(c+d) + b(c+d) = (a+b)(c+d)

 

  그럼 최종적으로 ab + (a+b)(c+d) + cd 형태가 된다. 즉, 예를들어 x,y,z를 합친 전투력을 f(xyz) 처럼 표현하고, x,y,z 각각의 전투력을 합친걸 g(xyz) = x+y+z 라고 해보자.

 

이 경우 xyz 그룹과 abc 그룹을 합친다면

f(xyzabc) = f(xyz) + g(xyz)*g(abc) + f(abc) 처럼 구할 수 있다.

이걸 어떻게 찾았냐면, 그냥 공책에 ab와 cd를 합쳐보다보니 저렇게 나오길래 다른 경우도 해보니 맞아서 ㅗㅜㅑ를 외치며 구현했다. 내 경우 수학이 매우 약해서.. 얻어걸린 셈이다.

 

  코드에서 sum[]은 g()를 뜻하고, power[]는 f() 를 뜻한다. 근데 배열에서 위에서 얘기한 xyz나 abc를 어떻게 표현할 수 있을까? 그리고 문제에서는 그룹끼리 합치는게 아니라 나도리의 번호가 주어진다! 이걸 보고 어느 그룹에 속해있는지 어떻게 표현가능할까? 이걸 해결해주는게 분리 집합 알고리즘(union-find)이다.

 

  이하는 문제에서 주어진 a,b 를 입력받아 두 나도리 그룹을 합치고, 합쳐진 나도리 그룹의 전투력을 리턴하는 함수에 설명을 달아둔 것이다. 기본적으로는 union-find 알고리즘을 알고 있어야 이해할 수 있으므로 모른다면 그걸 먼저 익혀보자.

long solve(int a, int b) {
    a = find(a);	// a를 그룹번호로 변경
    b = find(b);	// b를 그룹번호로 변경
    if (a == b) return power[a];	// 동일한 그룹이라면 해당 그룹의 전투력 리턴

	// 이하 4줄은 그냥 분리집합 알고리즘을 효율적으로 돌리기 위한 부분이다.
    // 잘 모르겠으면 hi를 a, lo를 b라고 봐도 된다.
    int hi = parents[a] < parents[b] ? a:b;
    int lo = parents[a] < parents[b] ? b:a;
    parents[hi] += parents[lo];
    parents[lo] = hi;

	// 글에서 설명한대로 얘기해보면 power[hi]는 f(xyzabc)인 셈이고,
    // += 이므로 이미 f(xyz)에다가
    // power[lo]인 f(abc)를 더하고
    // sum[hi]*sum[lo]로 g(xyz)*g(abc)를 더한다. 즉, 글에서 설명한것과 동일하다.
    power[hi] += power[lo] + sum[hi] * sum[lo];
    
    // g(xyzabc) = g(xyz) + g(abc)
    sum[hi] += sum[lo];

	// return f(xyzabc)
    return power[hi];
}

 

 

코드 : github

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.StringTokenizer;

public class Main {
    static BufferedReader br = new BufferedReader(new InputStreamReader(System.in), 1<<16);

    public static void main(String[] args) throws Exception {
        new Main().solution();
    }

    int[] parents;
    long[] sum, power;

    int find(int a) {
        if (parents[a] < 0) return a;
        return parents[a] = find(parents[a]);
    }

    long solve(int a, int b) {
        a = find(a);
        b = find(b);
        if (a == b) return power[a];

        int hi = parents[a] < parents[b] ? a:b;
        int lo = parents[a] < parents[b] ? b:a;
        parents[hi] += parents[lo];
        parents[lo] = hi;

        power[hi] += power[lo] + sum[hi] * sum[lo];
        sum[hi] += sum[lo];

        return power[hi];
    }

    public void solution() throws Exception {
        StringTokenizer st = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(st.nextToken());
        int q = Integer.parseInt(st.nextToken());

        parents = new int[n];
        Arrays.fill(parents, -1);
        sum = new long[n];
        power = new long[n];

        st = new StringTokenizer(br.readLine());
        for (int i = 0; i < n; i++) sum[i] = Integer.parseInt(st.nextToken());

        StringBuilder sb = new StringBuilder();
        while (q-->0) {
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken()) - 1;
            int b = Integer.parseInt(st.nextToken()) - 1;
            sb.append(solve(a, b)).append('\n');
        }

        System.out.print(sb);
    }
}

 

댓글