1일 1PS 87일차!

📚 문제


[골드 3] - 학교 탐방하기

💡 풀이 과정


  • MST(최소 신장 트리) 문제이다.
  • Java 를 풀 때는 Node 라는 클래스를 추가한다.
  • 간선을 모두 추가하고 최대, 최소로 sort 하기 보다는 Heap 을 사용해 정렬을 대신한다.
  • MST 시 사이클 존재 여부는 Union-Find 를 통해 검증해낸다.

📌포인트


  • parent
    • Union-Find 를 두 번 사용하기에 꼭 초기화 해줘야한다.
  • class Node
    • Node 라는 클래스를 생성해 사용하면 좋다.

💻 코드



import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

class Main {
    private static Queue<Node> minHeap = new PriorityQueue<>();
    private static Queue<Node> maxHeap = new PriorityQueue<>((v1, v2) -> v2.weight - v1.weight);
    static class Node implements Comparable<Node> {
        int v1, v2, weight;

        public Node(int v1, int v2, int weight) {
            this.v1 = v1;
            this.v2 = v2;
            this.weight = weight;
        }

        @Override
        public int compareTo(Node o) {
            return weight - o.weight;
        }
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());

        int N = Integer.parseInt(st.nextToken());
        int M = Integer.parseInt(st.nextToken());

        //Graph 생성
        // 내리막길 가중치는 1, 오르막길 가중치는 K^2
        for (int i = 0; i < M + 1; i++) {
            // 입구와 1번 정점간의 관계가 가장 먼저 주어진다.
            st = new StringTokenizer(br.readLine());
            int v1 = Integer.parseInt(st.nextToken());
            int v2 = Integer.parseInt(st.nextToken());
            int weight = Integer.parseInt(st.nextToken());

            minHeap.add(new Node(v1, v2, weight));
            maxHeap.add(new Node(v1, v2, weight));
        }

        //최소 피로도 계산 MST
        int[] parent = new int[N + 1];
        for (int i = 0; i <= N; i++){
            parent[i] = i;
        }

        int cnt = 0;
        int maxCost = 0;
        while((cnt <= N) && !minHeap.isEmpty()){
            Node node = minHeap.poll();
            int v1 = node.v1;
            int v2 = node.v2;
            int weight = node.weight;

            if (find(parent, v1) != find(parent, v2)){
                union(parent, v1, v2);
                cnt++;
                if(weight == 0){
                    maxCost += 1;
                }
            }
        }

        for (int i = 0; i <= N; i++){
            parent[i] = i;
        }

        cnt = 0;
        int minCost = 0;
        while((cnt <= N) && !maxHeap.isEmpty()){
            Node node = maxHeap.poll();
            int v1 = node.v1;
            int v2 = node.v2;
            int weight = node.weight;

            if (find(parent, v1) != find(parent, v2)){
                union(parent, v1, v2);
                cnt++;
                if(weight == 0){
                    minCost += 1;
                }
            }
        }

        int answer = (int) Math.pow(maxCost, 2) - (int) Math.pow(minCost, 2);
        System.out.println(answer);
    }

    public static int find(int[] parent, int x){
        if (parent[x] == x){
            return x;
        }
        return parent[x] = find(parent, parent[x]);
    }
    public static void union(int[] parent, int x, int y){
        x = find(parent, x);
        y = find(parent, y);

        if (x != y){
            parent[y] = x;
        }
    }
}


업데이트:

댓글남기기