본문 바로가기
PS/BOJ

[자바, C++] 백준 10999 - 구간 합 구하기 2 (java cpp)

by Nahwasa 2022. 8. 15.

 문제 : boj10999


 

필요 알고리즘 개념

  • lazy propagation을 적용한 세그먼트 트리 혹은 펜윅 트리
    • 세그먼트 트리 + lazy propagation 혹은 range update, range query가 가능한 펜윅 트리를 알고 있어야 풀 수 있다.

※ 제 코드에서 왜 main 함수에 로직을 직접 작성하지 않았는지, 왜 Scanner를 쓰지 않고 BufferedReader를 사용했는지 등에 대해서는 '자바로 백준 풀 때의 팁 및 주의점' 글을 참고해주세요. 백준을 자바로 풀어보려고 시작하시는 분이나, 백준에서 자바로 풀 때의 팁을 원하시는 분들도 보시는걸 추천드립니다.

 


 

풀이

  세그먼트 트리 lazy propagation 혹은 range update, range query 펜윅 트리의 기본형태 문제이다. 펜윅 트리로 풀려면 작성해둔 펜윅 트리 글에서 '응용 3' 부분을 읽어보면 이 문제를 풀 수 있다.

 

 


 

코드(C++) : github

#include <iostream>
using namespace std;
#define ll long long

int n;
ll bit1[1000001] = {0,};
ll bit2[1000001] = {0,};

void update(int bitType, int idx, long diff) {
    ll* bit = bitType==1 ? bit1 : bit2;
    while (idx <= n) {
        bit[idx] += diff;
        idx += idx&-idx;
    }
}

void rangeUpdate(int a, int b, long diff) {
    update(1, a, diff);
    update(1, b+1, -diff);
    update(2, a, diff * (a-1));
    update(2, b+1, -diff * b);
}

ll getBitValue(int bitType, int idx) {
    ll* bit = bitType==1 ? bit1 : bit2;
    long answer = 0;
    while (idx > 0) {
        answer += bit[idx];
        idx -= idx&-idx;
    }
    return answer;
}

ll prefixSum(int idx) {
    return getBitValue(1, idx) * idx - getBitValue(2, idx);
}

ll query(int a, int b) {
    return prefixSum(b) - prefixSum(a-1);
}

int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int m, k;
    cin >> n >> m >> k;
    m += k;

    for (int i = 1; i <= n; i++) {
        ll cur;
        cin >> cur;
        rangeUpdate(i, i, cur);
    }

    while (m--) {
        int op, a, b;
        ll v;

        cin >> op;
        switch (op) {
            case 1:
                cin >> a >> b >> v;
                rangeUpdate(a, b, v);
                break;
            case 2:
                cin >> a >> b;
                cout << query(a, b) << '\n';
                break;
        }
    }
}

 

 

코드(Java) : github

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {
    long[] bit1, bit2;
    int n;

    private void update(int bitType, int idx, long diff) {
        long[] bit = bitType==1 ? bit1 : bit2;
        while (idx <= n) {
            bit[idx] += diff;
            idx += idx&-idx;
        }
    }

    private void rangeUpdate(int a, int b, long diff) {
        update(1, a, diff);
        update(1, b+1, -diff);
        update(2, a, diff * (a-1));
        update(2, b+1, -diff * b);
    }

    private long getBitValue(int bitType, int idx) {
        long[] bit = bitType==1 ? bit1 : bit2;
        long answer = 0;
        while (idx > 0) {
            answer += bit[idx];
            idx -= idx&-idx;
        }
        return answer;
    }

    private long prefixSum(int idx) {
        return getBitValue(1, idx) * idx - getBitValue(2, idx);
    }

    private long query(int a, int b) {
        return prefixSum(b) - prefixSum(a-1);
    }

    private void solution() throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        n = Integer.parseInt(st.nextToken());
        bit1 = new long[n+1];
        bit2 = new long[n+1];
        int mk = Integer.parseInt(st.nextToken()) + Integer.parseInt(st.nextToken());

        for (int i = 1; i <= n; i++) {
            rangeUpdate(i, i, Long.parseLong(br.readLine()));
        }

        StringBuilder sb = new StringBuilder();
        while (mk-->0) {
            st = new StringTokenizer(br.readLine());
            switch (Integer.parseInt(st.nextToken())) {
                case 1:
                    int a = Integer.parseInt(st.nextToken());
                    int b = Integer.parseInt(st.nextToken());
                    long diff = Long.parseLong(st.nextToken());
                    rangeUpdate(a, b, diff);
                    break;
                case 2:
                    int i = Integer.parseInt(st.nextToken());
                    int j = Integer.parseInt(st.nextToken());
                    sb.append(query(i, j)).append('\n');
                    break;
            }
        }
        System.out.print(sb);
    }

  public static void main(String[] args) throws Exception {
        new Main().solution();
    }
}

댓글