Algorithm/알고리즘 스터디(2021.07)

[Algorithm] Kruskal Algorithm (크루스칼 알고리즘)

binaryJournalist 2021. 9. 12. 23:16
반응형

 

 

https://www.inflearn.com/course/algorithm-%EC%95%8C%EA%B3%A0%EB%A6%AC%EC%A6%98-%EC%8B%A4%EC%8A%B5/lecture/12348

 

알고리즘의 개요와 실습 환경 구축 - 인프런 | 학습 페이지

지식을 나누면 반드시 나에게 돌아옵니다. 인프런을 통해 나의 지식에 가치를 부여하세요....

www.inflearn.com

 

 

인프런 알고리즘 강의 18강에 대한 리뷰다

 

 

크루스칼 알고리즘은 가장 적은 비용으로 모든 노드를 연결해주는 알고리즘이다.

 

다른 말로 최소 비용 신장 트리(Spanning Tree)라고도 한다.

 

신장 트리란 하나의 그래프가 있을 때 모든 노드를 포함하면서 사이클이 존재하지 않는 부분 그래프를 의미한다.

(트리의 성립 조건은 모든 노드가 포맣되어 서로 연결되면서 사이클이 존재하지 않는 것이다.)

크루스칼 알고리즘은 가능한 한 최소한의 비용으로 신장트리를 찾는 대표적인 알고리즘이다.

(*쿠르수칼 알고리즘은 그리디 알고리즘으로 분류된다고 한다.)

 

모든 길은 로마로 통한다면 가장 빠르고 저렴하게 돌지 않고 모든 도시에서 로마로 가는 방법을 찾는 알고리즘

 

보통 노드와 간선으로 표현되는데 사람에 따라 달리 표현되기도 한다.

 

*노드 = 정점 = 도시
*간선 = 거리 = 비용 = 노드 개수 - 1

 

 

출처: 위키피디아

 

 

 

 

 

먼저 모든 노드를 연결한 뒤에 노드가 전부 하나로 연결되어있는지 확인해야 한다.

그리고 그 중 가장 비용이 적은 간선을 포함시킨다.

구체적인 알고리즘은 다음과 같다.

 

  1. 간선 데이터를 비용에 따라 오름차순으로 정렬한다.
  2. 간선을 하나씩 확인하면서 현재의 간선에서 사이클이 발생하는지 확인한다.
    • 사이클이 발생하지 않는 경우 최소 신장 트리에 포함시킨다.
    • 사이클이 발생하는 경우 최소 신장 트리에서 제외시킨다.
  3. 모든 간선에 대하여 2번의 과정을 반복한다.

 

노드가 모두 연결되어 사이클이 발생하면 안되는데 사이클 발생 여부는 union-find 연산을 이용하여 확인한다.

(사이클이 발생하면 간선을 포함시키지 않고 사이클이 없으면 간선을 포함시킴)

 

*union 집합을 활용한 사이클 판별

간선을 하나씩 확인하면서 두 노드가 포함되어 있는 집합을 합치는 과정을 반복하여 사이클을 판별한다.

  1. 각 간선을 확인하여 두 노드의 루트 노드를 확인한다
    • 루트 노드가 서로 다르면 두 노드에 대하여 union 연산을 수행한다.
    • 루트 노드가 서로 같다면 사이클이 발생한 것으로 판단한다.
  2. 모든 간선에 대하여 1번의 과정을 반복한다.

 

*union-find 를 활용한 사이클 판별 소스

 

# 원소가 속한 집합 찾기
def find_parent(parent, x):
    # 루트 노드를 찾을 때까지 재귀적으로 호출
    if parent[x] != x:
        parent[x] = find_parent(parent, parent[x])
    return parent[x]

# 두 원소가 속한 집합을 합치기
def union_parent(parent, a, b):
    a = find_parent(parent, a)
    b = find_parent(parent, b)
    if a < b:
        parent[b] = a
    else:
        parent[a] = b

def isCycled(parent, a, b):
    return find_parent(parent, a) == find_parent(parent, b)

def main() :
    v, e = map(int, input().split())
    parent = [0] * (v + 1) # 부모 테이블 초기화
    for i in range(1, v + 1):
        parent[i] = i # 부모 노드를 자기 자신으로 초기화
    
    for i in range(e):
        a, b = map(int, input().split())
        if isCycled(parent, a, b):
            print("사이클이 발생했습니다")
            break
        else:
            union_parent(parent, a, b)
    
if __name__ == "__main__":
    # execute only if run as a script
    main()

 

 

 

크루스칼 알고리즘은 간선의 개수가 E개일 때, O(ElogE)의 시간 복잡도를 가진다.

크루스칼 알고리즘 중 시간이 가장 오래 걸리는 부분은 간선을 정렬하는 작업이기 때문에 간선 데이터를 정렬했을 때의 시간복잡도는 간선 개수에 따라 O(ElogE)이기 때문이다.

 

 

 

 

** C++

 

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

// 재귀적으로 부모 노드를 찾음 
int getParent(int parent[], int x) {
	if (parent[x] == x) return x;
	return parent[x] = getParent(parent, parent[x]);
}

// 두 부모 노드를 합치는 함수
int unionParent(int parent[], int a, int b) {
	a = getParent(parent, a);
	b = getParent(parent, b);
	if (a < b) parent[b] = a;
	else parent[a] = b;
}

// 같은 부모인지 확인하는 함수
int findParent(int parent[], int a, int b) {
	a = getParent(parent, a);
	b = getParent(parent, b);
	if (a == b) return 1;
	return 0;
}

// 간선 클래스
class Edge {
	public:
		int node[2];
		int distance;
		// 초기화 
		Edge(int a, int b, int distance) {
			this->node[0] = a;
			this->node[1] = b;
			this->distance = distance;
		}
		bool operator <(Edge &edge) {
			return this->distance < edge.distance;
		}
}; 

int main(void) {
	int n = 7;
	int m = 11;
	vector<Edge> v;
	v.push_back(Edge(1, 7, 12));
	v.push_back(Edge(1, 4, 28));
	v.push_back(Edge(1, 2, 67));
	v.push_back(Edge(1, 5, 17));
	v.push_back(Edge(2, 4, 24));
	v.push_back(Edge(2, 5, 62));
	v.push_back(Edge(3, 5, 20));
	v.push_back(Edge(3, 6, 37));
	v.push_back(Edge(4, 7, 13));
	v.push_back(Edge(5, 6, 45));
	v.push_back(Edge(5, 7, 73));
	
	// 간선 비용을 기준으로 오름차순 정렬 
	sort(v.begin(), v.end());
	
	// 각 정점이 포함된 그래프가 어디인지 저장
	int parent[n];
	for (int i=0; i < n; i++) {
		parent[i] = i;
	} 
	int sum = 0;
	for (int i=0; i < v.size(); i++) {
		// 사이클이 발생하지 않는 경우 그래프에 포함시킴 
		if (!findParent(parent, v[i].node[0] - 1, v[i].node[1] - 1)) {
			sum += v[i].distance;
			unionParent(parent, v[i].node[0] - 1, v[i].node[1] - 1);
		}
	}
	
	printf("%d\n", sum);
}

 

 

 

** Python

 

# 원소가 속한 집합 찾기
def find_parent(parent, x):
    # 루트 노드를 찾을 때까지 재귀적으로 호출
    if parent[x] != x:
        parent[x] = find_parent(parent, parent[x])
    return parent[x]

# 두 원소가 속한 집합을 합치기
def union_parent(parent, a, b):
    a = find_parent(parent, a)
    b = find_parent(parent, b)
    if a < b:
        parent[b] = a
    else:
        parent[a] = b

def isCycled(parent, a, b):
    return find_parent(parent, a) == find_parent(parent, b)

def main() :
    n, m = 7, 11 # 노드의 개수와 간선(union 연산)의 개수
    parent = [0] * (n + 1) # 부모 테이블 초기화
    for i in range(1, n + 1):
        parent[i] = i # 부모 노드를 자기 자신으로 초기화
    
    # 모든 간선을 담을 리스트와 최종 비용을 담을 변수
    edges = []
    result = 0
    
    # 모든 간선에 대한 정보를 입력받기 (cost, a, b)
    # 비용순으로 정리하기 위해 cost를 가장 앞으로 뺌
    edges.append((12, 1, 7))
    edges.append((28, 1, 4))
    edges.append((67, 1, 2))
    edges.append((17, 1, 5))
    edges.append((24, 2, 4))
    edges.append((62, 2, 5))
    edges.append((20, 3, 5))
    edges.append((37, 3, 6))
    edges.append((13, 4, 7))
    edges.append((45, 5, 6))
    edges.append((73, 5, 7))
    
    edges.sort()
    
    for edge in edges:
        cost, a, b = edge
        if not isCycled(parent, a, b):
            union_parent(parent, a, b)
            result += cost
    print(result)
    
if __name__ == "__main__":
    # execute only if run as a script
    main()

 

 

** Javascript

 

 

// 원소가 속한 집합 찾기
function find_parent(parent, x) {
    // 루트 노드를 찾을 때까지 재귀적으로 호출
    if (parent[x] != x) parent[x] = find_parent(parent, parent[x]);
    return parent[x];
}

// 두 원소가 속한 집합을 합치기
function union_parent(parent, a, b) {
    a = find_parent(parent, a);
    b = find_parent(parent, b);
    if (a < b) {
        parent[b] = a;
    }
    else {
        parent[a] = b;
    }
}

function isCycled(parent, a, b) {
    return find_parent(parent, a) == find_parent(parent, b);
}

function main() {
    let n = 7, m = 11; // 노드의 개수와 간선(union 연산)의 개수
    let parent = new Array(n + 1).fill(0) // 부모 테이블 초기화
    for (let i = 1; i <= n; i++) parent[i] = i; // 부모 노드를 자기 자신으로 초기화
    
    // 모든 간선을 담을 리스트와 최종 비용을 담을 변수
    let edges = []
    let result = 0
    
    // 모든 간선에 대한 정보를 입력받기 (cost, a, b)
    edges.push([12, 1, 7])
    edges.push([28, 1, 4])
    edges.push([67, 1, 2])
    edges.push([17, 1, 5])
    edges.push([24, 2, 4])
    edges.push([62, 2, 5])
    edges.push([20, 3, 5])
    edges.push([37, 3, 6])
    edges.push([13, 4, 7])
    edges.push([45, 5, 6])
    edges.push([73, 5, 7])
    
    edges.sort((a, b) => a[0] - b[0]);
    
    for (const edge of edges) {
        const [cost, a, b] = edge;
        if (!isCycled(parent, a, b)) {
            union_parent(parent, a, b);
            result += cost;
        }
    }
    console.log(result);
}

main();

 

 

** Java

 

import java.util.*;
public class Main
{
    private static int findParent(int[] parent, int x) {
        if (parent[x] != x) {
            parent[x] = findParent(parent, parent[x]);
        }
        return parent[x];
    }
    
    private static void unionParent(int[] parent, int a, int b) {
        a = findParent(parent, a);
        b = findParent(parent, b);
        if (a < b) {
            parent[b] = a;
        } else {
            parent[a] = b;
        }
    }
    
    private static boolean isCycled(int[] parent, int a, int b) {
        return findParent(parent, a) == findParent(parent, b);
    }
    
	public static void main(String[] args) {
		int n = 7;
		int m = 11;
		int[] parent = new int[n + 1];
		for (int i = 1; i <= n; i++) {
		    parent[i] = i;
		}
		int[][] edges = {
		    {12, 1, 7}, {28, 1, 4}, {67, 1, 2}, {17, 1, 5},
		    {24, 2, 4}, {62, 2, 5}, {20, 3, 5}, {37, 3, 6},
		    {13, 4, 7}, {45, 5, 6}, {73, 5, 7}
		};
		int result = 0;
		
		Arrays.sort(edges, (a, b) -> {
		    return a[0] - b[0];
		});
		
		for (int i = 0; i < edges.length; i++) {
		    if(!isCycled(parent, edges[i][1], edges[i][2])) {
		        unionParent(parent, edges[i][1], edges[i][2]);
		        result += edges[i][0];
		    }
		}
		System.out.println(result);
	}
}
반응형