본문 바로가기
PS/BOJ

[자바] 백준 10830 - 행렬 제곱 (java)

by Nahwasa 2022. 8. 2.

 문제 : boj10830


 

필요 알고리즘 개념

  •  행렬 곱셈 (수학)
    • 행렬끼리 곱하는 방법을 알아야 한다.
  • 분할정복을 이용한 거듭제곱
    • 분할정복을 이용해 거듭제곱을 최적화하는 방법을 알아야 한다.
  • 나머지 연산의 분배법칙
    • 나머지 연산의 분배법칙을 알고 있어야 풀 수 있다.

※ 제 코드에서 왜 main 함수에 로직을 직접 작성하지 않았는지, 왜 Scanner를 쓰지 않고 BufferedReader를 사용했는지 등에 대해서는 '자바로 백준 풀 때의 팁 및 주의점' 글을 참고해주세요. 백준을 자바로 풀어보려고 시작하시는 분이나, 백준에서 자바로 풀 때의 팁을 원하시는 분들도 보시는걸 추천드립니다.

 


 

풀이

 

우선 행렬끼리 곱하는 방법을 모른다면, 구글링을 통해 알아보자.

기초 수학이므로 반드시 알고 있어야 한다. N이 3인 경우를 예로 들면 다음과 같다.

 

 


 

다음으로 분할정복을 이용한 거듭제곱을 알아야 한다.

아마 다음은 이미 알고 있을 것이다.

예를들어 거듭제곱되는 수치가 짝수일 때와 홀수일 때를 예로들면 다음과 같다.

A^4 = (A^2)^2 (짝수일때)
A^5 = A * A^4 = A * (A^2)^2 (홀수일때)

 

만약 A^8이 있다면 원래는 A*A*A*A*A*A*A*A 으로 곱셈 연산을 7번 해야 하지만, A^8을 ((A^2)^2)^2 으로 변경하면 A*A=A', A'^2=A'' 이라 할 시 최종적으로 A'을 구하는데에 A*A로 곱셉 한번, A''=A'*A'을 구하는데도 마찬가지로 곱셈 한번, 최종적으로 A''*A''을 구하는데 곱셈 한 번이 들어간다. 즉 7번의 연산이 3번의 연산으로 줄어든다!

 

즉, 거듭제곱을 위와 같이 계산한다면 원래 N번의 연산이 O(logN)으로 줄어든다. 여기서 분할정복을 활용한다는 말은 위의 공식을 코드로 짤 때 분할정복을 활용하면 쉽게 짤 수 있기 때문이다. 예를들어 위의 방법을 활용해 자바로 짠 A^N을 구하는 코드는 다음과 같다.

private static long pow(long a, long n) {
    if (n == 1)
        return a;
    long tmp = pow(a, n/2);
    if (n%2==0)
        return tmp*tmp;
    else
        return a*tmp*tmp;
}

 

또한 단순히 정수의 거듭제곱 계산에만 활용한다면 큰 의미는 없겠지만, 이 방법을 활용하면 행렬의 거듭제곱 등 다른 부분에도 응용할 수 있다. 예를들어 피보나치수의 경우 원래 알고 있는 f(n) = f(n-1) + f(n-2) (n>=2, f(0)=0, f(1)=1) 를 계속해서 구하는 것 보다 더 빠르게 다음의 행렬 거듭제곱으로 구할 수 있는 공식이 있다.

여기에도 마찬가지로 행렬의 거듭제곱을 계산하기 위해 위에서 말한 분할정복을 활용해 거듭제곱을 계산할 수 있다. 자바 코드로 구현해보면 다음과 같다. (MOD의 경우는 값이 너무 커지므로 해당 값으로 나눈 나머지를 출력한 것이다.)

private static final long MOD = 1000007l;

    private long[][] matrixMult(long[][] a, long[][] b) {
        long[][] arr = new long[2][2];
        arr[0][0] = (a[0][0]*b[0][0]%MOD + a[0][1]*b[1][0]%MOD)%MOD;
        arr[1][0] = (a[0][0]*b[0][1]%MOD + a[0][1]*b[1][1]%MOD)%MOD;
        arr[0][1] = (a[1][0]*b[0][0]%MOD + a[1][1]*b[1][0]%MOD)%MOD;
        arr[1][1] = (a[1][0]*b[0][1]%MOD + a[1][1]*b[1][1]%MOD)%MOD;
        return arr;
    }

    private long[][] fibo(long n) {
        if (n == 1) {
            long[][] arr = {{1,1},{1,0}};
            return arr;
        }
        long[][] tmp = fibo(n/2);
        if (n%2==1) {
            return matrixMult(matrixMult(tmp, tmp), fibo(1));
        } else {
            return matrixMult(tmp, tmp);
        }
    }

  위의 로직을 사용할 시, 1,000,000,000,000,000,000 번째 피보나치조차도 O(logN)을 하면 사실 60번의 행렬 곱셈만으로 구할 수 있다(2^60 = 1,152,921,504,606,846,976). 진짜 저걸 다 곱해보면 대략 317년정도 걸릴 연산(대략 1억번 연산에 1초정도로 잡았다.)이 로직 좀 변경했다고 0.1초도 안되서 끝나는 것이다.

 

 

  그럼 위에서 설명한 내용을 이 문제에 대입시켜보자. 위에서 2x2 행렬의 곱셈까진 설명했는데, 이번엔 NxN 행렬을 곱해야 하므로 matrixMult 부분이 다음과 같이 변경되면 될 것이다.

private int[][] matrixMult(int[][] a, int[][] b) {
    int[][] arr = new int[n][n];
    for (int i = 0; i < n; i++) {   // row of a
        for (int j = 0; j < n; j++) {   // column of b
            for (int x = 0; x < n; x++) {
                arr[i][j] += a[i][x]*b[x][j];
            }
            arr[i][j] %= MOD;
        }
    }
    return arr;
}

 


 

나머지 연산의 분배법칙도 알아야 한다.

  우선 나머지 연산의 분배법칙부터 보자.

  즉, 어차피 덧셈으로 이루어져 있다면 매번 나머지 연산을 진행해도 결과에 영향을 끼치지 않는다는 얘기이다. 내 경우엔 코드에서 보듯이 매번 matrixMult를 할 때 마다 arr[i][j] %= MOD; 를 해준다. matrixMult가 불릴 때 마다 a와 b에 포함된 모든 원소가 항상 1000이하라면, 각 행렬 곱셈 결과의 최대값은 이 문제에서 5x1000x1000이 된다. 그걸 다시 나머지 연산의 분배법칙에 따라 1000으로 나눈 나머지만 남겨둔다면 언제나 matrixMult의 인자 a,b의 모든 원소가 1000 이하임이 보장된다.

 


 

주의할 점

  처음 입력으로 들어온 행렬의 각 원소는 1000이하의 수이다. 그리고 B의 최소값은 1이다. 그러므로 다음과 같은 예시를 주의해야한다.

2 1
1000 1000
1000 1000

 

  이 경우, 행렬 곱셈이 일어나지 않으므로 그냥 그대로 출력할 수 있으나 문제 조건에서 1000으로 나눈 나머지를 출력해줘야 하므로 답은 아래와 같이 출력되야 한다.

0 0
0 0

 


 

코드 : github

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {
    private static final int MOD = 1000;
    private int[][] baseMatrix;
    private int n;

    private int[][] matrixMult(int[][] a, int[][] b) {
        int[][] arr = new int[n][n];
        for (int i = 0; i < n; i++) {   // row of a
            for (int j = 0; j < n; j++) {   // column of b
                for (int x = 0; x < n; x++) {
                    arr[i][j] += a[i][x]*b[x][j];
                }
                arr[i][j] %= MOD;
            }
        }
        return arr;
    }

    private int[][] matrixPow(long b) {
        if (b == 1) {
            return baseMatrix;
        }
        int[][] tmp = matrixPow(b/2);
        int[][] tmpPow2 = matrixMult(tmp, tmp);
        return b%2==0 ? tmpPow2 : matrixMult(tmpPow2, matrixPow(1));
    }

    private void solution() throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        n = Integer.parseInt(st.nextToken());
        long b = Long.parseLong(st.nextToken());
        int[][] matrix = new int[n][n];
        for (int i = 0; i < n; i++) {
            st = new StringTokenizer(br.readLine());
            for (int j = 0; j < n; j++) {
                matrix[i][j] = Integer.parseInt(st.nextToken());
                matrix[i][j] %= MOD;
            }
        }

        baseMatrix = matrix;
        int[][] answer = matrixPow(b);
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                sb.append(answer[i][j]).append(' ');
            }
            sb.append('\n');
        }
        System.out.print(sb);
    }
    public static void main(String[] args) throws Exception {
        new Main().solution();
    }
}

댓글