본문 바로가기
PS/BOJ

[자바] 백준 2632 - 피자판매 (java)

by Nahwasa 2023. 7. 20.

목차

    문제 : boj2632

     

     

    필요 알고리즘

    • 누적 합 (prefix sum), 해시를 사용한 집합과 맵
      • 누적합을 통해 구간의 합을 빠르게 구하고, 그 값을 빠르게 찾을 수 있으면 된다. 다만 원형이므로 약간의 아이디어가 필요하다.

    ※ 제 코드에서 왜 main 함수에 로직을 직접 작성하지 않았는지, 왜 Scanner를 쓰지 않고 BufferedReader를 사용했는지 등에 대해서는 '자바로 백준 풀 때의 팁 및 주의점' 글을 참고해주세요. 백준을 자바로 풀어보려고 시작하시는 분이나, 백준에서 자바로 풀 때의 팁을 원하시는 분들도 보시는걸 추천드립니다.

     

     

    풀이

      우선 누적합 알고리즘을 모르면 풀기 매우 어려워지고, 어렵지 않으면서 엄청 자주 쓰이는 알고리즘이므로 모른다면 이 글을 먼저 읽자.

     

      문제를 좀 단순화해서, 그냥 두 배열 A, B가 주어졌다고 해보자. 그리고 찾고하는건 A의 모든 연속된 구간의 합(공집합 포함)과 B의 모든 연속된 구간의 합(공집합 포함)을 합친게 G인 경우의 수를 구하는거다.

     

      이 경우 A와 B가 충분히 크기가 작다면, 다음의 로직으로 어렵지 않게 구할 수 있다. A배열의 크기가 X, B배열의 크기가 Y라고 하겠다.

     

     

    1. A에 대한 누적합 배열을 만든다. O(X)

     

    2. B에 대한 누적합 배열을 만든다. O(Y)

     

    3. 누적합 배열의 전체 구간에 대해 누적합을 구해서, 해당 누적합이 나온 횟수를 어딘가에 저장해둔다. 이 때 구하면서 나온 누적합이 G와 같다면 그 횟수도 세어둔다. 누적합의 최대치가 작다면 배열로 둬도 되겠으나 크다면 해싱해두는게 편하다. 내 경우엔 해싱으로(HashMap) 처리했다. O(X^2)

    -> 예를들어 A = [1,2,1] 이었다면 누적합 배열은 [1,3,4]이 될 것이고, 모든 연속된 구간에 대한 누적합은 각각 1이 한번([1]), 3이 한번([3]), 4가 두번([1,3], [4]), 7이 한번([3,4]), 8이 한번([1,3,4]) 나온다.

     

    4. B도 마찬가지로 모든 연속된 구간에 대한 합을 구하면서, 나온 누적합이 G와 같다면 카운팅해둔다. 그리고 이 때 'G - 현재 나온 누적합'이 '3'에서 저장해둔 해시에 존재한다면. 그 횟수만큼 카운팅한다. 혹은 배열로 처리했다면 이분탐색으로 동일하게 횟수를 찾아 카운팅한다. O(Y^2) x O(해싱이라면 1 || 이분탐색이라면 log(X^2))

     

     

     

      여기까진 사실 어느정도 알고리즘을 풀어봤다면 웰-논이라 쉽게 아이디어를 떠올릴 수 있다. 문제는 이 문제는 고정된 배열 A, B가 아니라 원형으로 생긴 피자 A, B 라는 점이다. 즉, 피자가 순서대로 1,2,3,4,5번 조각이라고 한다면 그냥 배열이라면 [1,2,3,4] 또는 [2,3,4,5] 이렇게가 전체가 아닌 최대 조각이다. 근데 원형이므로, [3,4,5,1], [4,5,1,2], ... 이런식으로도 가능해져버린다는 점이 문제다.

     

      근데 사실 메모리가 충분하다면 이걸 해결할 아주 간단한 아이디어가 있다. 그냥 애초에 입력받을 때, [1,2,3,4,5,1,2,3,4,5] 로 2배로 받아둬버리면 끝난다. 원형에 대한 연산이 위에서 설명한 A, B 배열에 대한 로직으로 바꼈다. 그럼 위에서 설명한걸 2배로 늘려서 해버리면 된다! 하나만 주의하면 되는데, 모든 연속된 구간을 찾을 때 이미 썼던 조각을 다시 쓰면 안되므로(이미 먹은 피자를 다시 먹을 순 없으므로) 늘어나는 범위만 제한해주면 된다. (e.g. 코드의 'for (int j = i; j < i+a-1; j++)' 부분에 보면 [i, i+a-1) 까지로 범위를 제한함을 볼 수 있다.)

     

     

    1, 2.

    for (int i = 1; i <= a*2; i++) arrA[i] += arrA[i-1];
    for (int i = 1; i <= b*2; i++) arrB[i] += arrB[i-1];

     

    3.

    Map<Integer, Integer> cnt = new HashMap<>();
    if (arrA[a] == g) answer++;
    cnt.put(arrA[a], 1);
    
    for (int i = 1; i <= a; i++) {
        for (int j = i; j < i+a-1; j++) {
            int rangeSum = arrA[j]-arrA[i-1];
            if (rangeSum == g) answer++;
            if (rangeSum >= g) continue;
            cnt.put(rangeSum, cnt.getOrDefault(rangeSum, 0) + 1);
        }
    }

     

    4. 

    if (arrB[b] == g) answer++;
    answer += cnt.getOrDefault(g-arrB[b], 0);
    
    for (int i = 1; i <= b; i++) {
        for (int j = i; j < i+b-1; j++) {
            int rangeSum = arrB[j]-arrB[i-1];
            if (rangeSum == g) answer++;
    
            answer += cnt.getOrDefault(g-rangeSum, 0);
        }
    }

     

     

    코드 : github

    import java.io.BufferedReader;
    import java.io.InputStreamReader;
    import java.util.HashMap;
    import java.util.Map;
    import java.util.StringTokenizer;
    
    public class Main {
    
        static BufferedReader br = new BufferedReader(new InputStreamReader(System.in), 1<<16);
    
        public static void main(String[] args) throws Exception {
            new Main().solution();
        }
    
        private void solution() throws Exception {
            int g = Integer.parseInt(br.readLine());
            StringTokenizer st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
    
            int[] arrA = new int[a*2+1];
            for (int i = 1; i <= a; i++) arrA[i] = arrA[a+i] = Integer.parseInt(br.readLine());
            int[] arrB = new int[b*2+1];
            for (int i = 1; i <= b; i++) arrB[i] = arrB[b+i] = Integer.parseInt(br.readLine());
    
            for (int i = 1; i <= a*2; i++) arrA[i] += arrA[i-1];
            for (int i = 1; i <= b*2; i++) arrB[i] += arrB[i-1];
    
            int answer = 0;
    
            Map<Integer, Integer> cnt = new HashMap<>();
            if (arrA[a] == g) answer++;
            cnt.put(arrA[a], 1);
    
            for (int i = 1; i <= a; i++) {
                for (int j = i; j < i+a-1; j++) {
                    int rangeSum = arrA[j]-arrA[i-1];
                    if (rangeSum == g) answer++;
                    if (rangeSum >= g) continue;
                    cnt.put(rangeSum, cnt.getOrDefault(rangeSum, 0) + 1);
                }
            }
    
            if (arrB[b] == g) answer++;
            answer += cnt.getOrDefault(g-arrB[b], 0);
    
            for (int i = 1; i <= b; i++) {
                for (int j = i; j < i+b-1; j++) {
                    int rangeSum = arrB[j]-arrB[i-1];
                    if (rangeSum == g) answer++;
    
                    answer += cnt.getOrDefault(g-rangeSum, 0);
                }
            }
    
            System.out.println(answer);
        }
    }

     

    댓글