본문 바로가기
PS/BOJ

백준 23801 자바 - 두 단계 최단 경로 2 (BOJ 23801 JAVA)

by Nahwasa 2021. 12. 6.

문제 : boj23801

 

 

  한 점에서 모든 점으로의 최단거리를 알 수 있는 알고리즘 중에 다익스트라 알고리즘이 있다.

이 문제의 경우, x에서 p개의 점 중 적어도 하나를 방문한 후 z까지 가야 한다. 그럼 다익스트라로 x부터 다른 모든 점으로의 거리를 찾아도 답을 구할 수 없다.

 

  이 문제의 경우 z점으로 들어가는 모든 지점의 거리도 알 수 있어야 한다. 이것도 다익스트라를 응용해서 할 수 있는데, 무방향 그래프라면 z에서 시작하는 다익스트라를 돌려도 그 결과가 모든 점에서부터 z로 향하는 최단거리와 동일하다. (참고로 방향이 있는 그래프일 경우에도 다익스트라로 모든 점에서부터 한 점으로의 최단거리를 구할 수 있는데, 모든 간선의 방향을 역방향으로 바꾼 후 다익스트라를 돌리면 된다.)

 

1. X에서 모든 점으로의 최단 거리를 구한다. ->X에서 시작하는 다익스트라로 구함

2. 모든 점에서 Z로의 최단 거리를 구한다. ->Z에서 시작하는 다익스트라로 구함

3. P개의 Y를 입력받아, i번째 Y가 Yi 라고 한다면 [X에서 Yi의 최단거리 + Yi에서 z의 최단거리]가 가장 작은 것을 출력해주면 된다. 그래야 적어도 하나의 Y를 지난 후 Z에 도착한 것이 된다.

 

총 시간복잡도는 대략 O(ElogV + ElogV + P)가 된다. (X에서 다익스트라 + Z에서 다익스트라 + P개의 Y에 대한 최단거리 확인)

 

 

 

코드 : github

import java.io.DataInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.PriorityQueue;

class Edge {
    int n;
    int d;
    public Edge(int n, int d) {
        this.n = n;
        this.d = d;
    }
}

class Pos implements Comparable<Pos> {
    int n;
    long dist;
    public Pos(int n, long dist) {
        this.n = n;
        this.dist = dist;
    }

    @Override
    public int compareTo(Pos o) {
        if (this.dist > o.dist) return 1;
        else if (this.dist < o.dist) return -1;
        return 0;
    }
}

public class Main extends FastInput {
    private static final long DIST_LIMIT = 300000l * 1000000l + 1;

    private long[] getMinDist(int n, int start, ArrayList<Edge>[] edges) {
        long[] dist = new long[n+1];
        Arrays.fill(dist, DIST_LIMIT);

        PriorityQueue<Pos> pq = new PriorityQueue<>();
        pq.add(new Pos(start, 0));
        dist[start] = 0;

        while (!pq.isEmpty()) {
            Pos cur = pq.poll();
            if (cur.dist > dist[cur.n]) continue;
            for (Edge next : edges[cur.n]) {
                long nextDist = cur.dist + next.d;
                if (nextDist < dist[next.n]) {
                    dist[next.n] = nextDist;
                    pq.add(new Pos(next.n, nextDist));
                }
            }
        }
        return dist;
    }

    private void solution() throws Exception {
        int n = nextInt();
        int m = nextInt();

        ArrayList<Edge>[] edges = new ArrayList[n+1];
        for (int i = 1; i <= n; i++) {
            edges[i] = new ArrayList<>();
        }

        while(m-->0) {
            int u = nextInt();
            int v = nextInt();
            int w = nextInt();
            edges[u].add(new Edge(v, w));
            edges[v].add(new Edge(u, w));
        }
        int x = nextInt();
        int z = nextInt();

        long[] fromX = getMinDist(n, x, edges);
        long[] toZ = getMinDist(n, z, edges);

        long answer = DIST_LIMIT;
        int p = nextInt();
        while(p-->0) {
            int y = nextInt();
            if (fromX[y] == DIST_LIMIT || toZ[y] == DIST_LIMIT) continue;

            long distSum = fromX[y] + toZ[y];
            if (answer > distSum)
                answer = distSum;
        }
        System.out.println(answer == DIST_LIMIT ? -1 : answer);
    }

    public static void main(String[] args) throws Exception {
        initFI();
        new Main().solution();
    }
}

class FastInput {
    private static final int DEFAULT_BUFFER_SIZE = 1 << 16;
    private static DataInputStream inputStream;
    private static byte[] buffer;
    private static int curIdx, maxIdx;

    protected static void initFI() {
        inputStream = new DataInputStream(System.in);
        buffer = new byte[DEFAULT_BUFFER_SIZE];
        curIdx = maxIdx = 0;
    }

    protected static int nextInt() throws IOException {
        int ret = 0;
        byte c = read();
        while (c <= ' ') c = read();
        do {
            ret = ret * 10 + c - '0';
        } while ((c = read()) >= '0' && c <= '9');
        return ret;
    }

    private static byte read() throws IOException {
        if (curIdx == maxIdx) {
            maxIdx = inputStream.read(buffer, curIdx = 0, DEFAULT_BUFFER_SIZE);
            if (maxIdx == -1) buffer[0] = -1;
        }
        return buffer[curIdx++];
    }
}

댓글