본문 바로가기
PS/BOJ

백준 2191 자바 - 들쥐의 탈출 (BOJ 2191 JAVA)

by Nahwasa 2022. 3. 20.

문제 : boj2191

 

 

  이 문제에서 뭐' 속도가 빠른 쥐가 더 멀리 가야만한다'와 같은 가중치적인 부분은 들어있지 않다. N개의 정점을 M개의 정점에 최대한 많이 매칭 시키기만 하면 되는 문제이다.

 

1. 그래프 형태로 전처리

  문제에서 제시된 값들만 가지고는 뭔가 풀어보기가 힘들다. 따라서 우선 그래프 형태로 변경해보자. 각 쥐의 현재 x, y 좌표와 각 땅굴의 x, y 좌표를 알고 있다. a번째 쥐의 좌표를 Xa, Ya라 하고, b번째 땅굴의 좌표를 Xb, Yb라 하겠다. 그럼 쥐가 땅굴로 달리기 전 a번째 쥐와 b번째 땅굴의 거리는 다음과 같다.

  그리고 쥐는 초당 V만큼 움직이고, S초까지 들어가야 하므로 다음 식이 만족된다면 a번째 쥐에서 b번째 땅굴로 안전하게 들어갈 수 있다. ('단, 들쥐가 도착하는 시간이 정확히 S인 경우에도 그 들쥐는 도망칠 수 있는 것으로 간주' 이므로 '<=' 이다.)

 

  따라서 위 식을 만족한다면 a->b 형태의 간선을 만들어볼 수 있다. 예를들어 다음과 같은 입력을 확인해보자.

3 4 2 5
1.0 1.0
15.0 15.0
30.0 30.0
0.0 0.0
8.0 8.0
22.0 22.0
23.0 23.0

  우선 V*S = 10 이다. 따라서 (Xa-Xb)^2 + (Ya-Yb)^2이 100 이하면 간선을 연결할 수 있다. 이하 간선을 그림으로 나타냈다. 그리고 그려보면 알겠지만, 이 문제에서 나온 걸 그래프로 그리게 되면 이분 그래프가 나온다.

  TMI로, 코드에서는 실수 오차를 줄이기 위해 최대한 나누기나 루트같은걸 안쓰는게 정확도에 좋다. 따라서 양 변을 제곱해서 다음과 같이 푸는게 더 좋다.

 

 

2. 그럼 이제 이분 그래프가 만들어졌으니 최대한 많이 매칭시켜보자.

  이미 이분 그래프가 나왔으므로, 네트워크 플로우 중 이분 그래프에 대해 매칭시키는 bipartite matching 알고리즘을 적용해서 풀 수 있다. 이에 대한 설명은 여기(유투브 - MIT OpenCourseWare - Bipartite Matching) 등을 참고하자. 다른 문제이긴 하지만 작성한 글 중 이 글에서 대강의 매칭 과정을 확인할 수 있다.

 

  어떤 쥐가 어떤 굴에 들어갔냐는 중요하지 않고, 최대한 매칭 시킨 후 남는 쥐의 수만 구하면 'n-남은 쥐의 수'가 답이다.

 

 

코드 : github

import java.io.DataInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;

public class Main extends FastInput {
    boolean[] v;
    int[] matched;
    int m;
    ArrayList<Integer>[] edges;

    private boolean matching(int from) {
        for (int i = 0; i < edges[from].size(); i++) {
            int to = edges[from].get(i);
            if (v[to]) continue;
            v[to] = true;
            if (matched[to] == -1 || matching(matched[to])) {
                matched[to] = from;
                return true;
            }
        }
        return false;
    }

    private void solution() throws Exception {
        int n = nextInt();
        m = nextInt();
        int limit = nextInt()*nextInt();
        limit *= limit;
        edges = new ArrayList[n];
        for (int i = 0; i < n; i++) edges[i] = new ArrayList<>();
        double[][] tmp = new double[n][2];
        for (int i = 0; i < n; i++) {
            tmp[i][0] = nextDouble();
            tmp[i][1] = nextDouble();
        }
        for (int i = 0; i < m; i++) {
            double x = nextDouble();
            double y = nextDouble();
            for (int j = 0; j < n; j++) {
                double dist = (x-tmp[j][0])*(x-tmp[j][0])+(y-tmp[j][1])*(y-tmp[j][1]);
                if (dist <= limit)
                    edges[j].add(i);
            }
        }

        int cnt = 0;
        matched = new int[m];
        Arrays.fill(matched, -1);
        for (int i = 0; i < n; i++) {
            v = new boolean[m];
            if (matching(i))
                cnt++;
        }
        System.out.println(n-cnt);
    }

    public static void main(String[] args) throws Exception {
        initFI();
        new Main().solution();
    }
}

class FastInput {
    private static final int DEFAULT_BUFFER_SIZE = 1 << 16;
    private static DataInputStream inputStream;
    private static byte[] buffer;
    private static int curIdx, maxIdx;

    protected static void initFI() {
        inputStream = new DataInputStream(System.in);
        buffer = new byte[DEFAULT_BUFFER_SIZE];
        curIdx = maxIdx = 0;
    }

    protected static int nextInt() throws IOException {
        int ret = 0;
        byte c = read();
        while (c <= ' ') c = read();
        do {
            ret = ret * 10 + c - '0';
        } while ((c = read()) >= '0' && c <= '9');
        return ret;
    }

    protected static double nextDouble() throws IOException {
        double ret = 0, div = 1;
        byte c = read();
        while (c <= ' ') c = read();
        boolean neg = (c == '-');
        if (neg) c = read();
        do {
            ret = ret * 10 + c - '0';
        } while ((c = read()) >= '0' && c <= '9');
        if (c == '.') while ((c = read()) >= '0' && c <= '9') ret += (c - '0') / (div *= 10);
        if (neg) return -ret;
        return ret;
    }

    private static byte read() throws IOException {
        if (curIdx == maxIdx) {
            maxIdx = inputStream.read(buffer, curIdx = 0, DEFAULT_BUFFER_SIZE);
            if (maxIdx == -1) buffer[0] = -1;
        }
        return buffer[curIdx++];
    }
}

댓글