[グラフ] プリム法

Python でクラスカル法を実装します。 SciPy で最小全域木を求める時はクラスカル法を使っています。 scipy.sparse.csgraph.minimum_spanning_tree Vertex と Edge 与えられた...

プリム法はクラスカル法と同じ最小全域木を探すアルゴリズムです。

最短経路を探すダイクストラ法にとても良く似ています。

プリム法

プリム法とは、グラフ理論で重み付き連結グラフの最小全域木を求める最適化問題のアルゴリズムである。全域木(対象となるグラフの全頂点を含む辺の部分集合で構成される)のうち、その辺群の重みの総和が最小となる木を求めるものである。

出典: フリー百科事典『ウィキペディア(Wikipedia)』

クラスカル法では、素集合データ構造で複数の木構造を作り、閉路ができないように「辺」を選ぶことで、最終的に最小全域木を作りました。

プリム法は現在の最小全域木に隣接する「頂点」の中で、最小の辺となるものを選ぶことで最小全域木を作っていきます。

最小の頂点の選択にヒープを使うと、新しく辺が選ばれた際に追加される頂点の数だけヒープの再構成を行うので、2分ヒープを使った場合は計算量は\( O ( E \log (V) ) \) になります。

アルゴリズムと データ構造

正しさの証明は以下。

正しさの証明

また、具体的な流れは以下。

プリム法とクラスカル法の使い分け

When should I use Kruskal as opposed to Prim (and vice versa)?

基本的にはクラスカル法が良いようです。

クラスカル法の計算量は\( O (E \log V ) \)、プリム法はフィボナッチヒープを使うと \( O (V \log V) \) になるので、頂点の数に比して辺の数が多い稠密なグラフではプリム法が有利になります。

また、クラスカル法の計算量はソートによるので、既にソート済みの辺を扱う場合のようにソートの計算量を減らせる場合は、クラスカル法のほうが有利になります。

Python での実装

heapq のインポート

ヒープを使うためにインポートします。

heapq — ヒープキューアルゴリズム

import heapq

VertexEdge クラス

VertexEdge は別に定義しています。

ヒープの中で重さで最小の辺が取り出せるように、特殊メソッド __lt__ を設定します。

import heapq

class Vertex(object):

    def __init__(self, name):
        self.name = name
        self.visited = False
        self.adjacency_list = []


class Edge(object):

    def __init__(self, from_vertex, to_vertex, weight):
        self.from_vertex = from_vertex
        self.to_vertex = to_vertex
        self.weight = weight

    def __lt__(self, other_edge):
        return self.weight < other_edge.weight

プリム法

今いる頂点からまだ訪問していない頂点につながる最小の辺を、訪問していない頂点がなくなるまで選び続けます。

import heapq

class Vertex(object):

    def __init__(self, name):
        self.name = name
        self.visited = False
        self.adjacency_list = []


class Edge(object):

    def __init__(self, from_vertex, to_vertex, weight):
        self.from_vertex = from_vertex
        self.to_vertex = to_vertex
        self.weight = weight

    def __lt__(self, other_edge):
        return self.weight < other_edge.weight


class Prim(object):

    def __init__(self, vertex_list):
        self.unvisited_list = vertex_list
        self.edge_heap = []
        self.spanning_tree = []
        self.min_weight = 0

    def make_min_spanning_tree(self, start_vertex):

        current_vertex = start_vertex
        self.unvisited_list.remove(current_vertex)
        
        while self.unvisited_list:
            for edge in current_vertex.adjacency_list:
                if edge.to_vertex in self.unvisited_list:
                    heapq.heappush(self.edge_heap, edge)
            
            min_edge = heapq.heappop(self.edge_heap)

            if min_edge.to_vertex in self.unvisited_list:
                self.spanning_tree.append(min_edge)
                self.min_weight += min_edge.weight
                current_vertex = min_edge.to_vertex
                self.unvisited_list.remove(current_vertex)
    
    def show_min_spanning_tree(self):
        edges = []
        weight = 0        
        for edge in self.spanning_tree:
            edges.append(f'{edge.from_vertex.name}-{edge.to_vertex.name}')
            weight += edge.weight
        print(edges, 'weight:', weight)

テスト

クラスカル法と同じグラフの最小全域木を求めてみます。

import heapq

class Vertex(object):

    def __init__(self, name):
        self.name = name
        self.visited = False
        self.adjacency_list = []


class Edge(object):

    def __init__(self, from_vertex, to_vertex, weight):
        self.from_vertex = from_vertex
        self.to_vertex = to_vertex
        self.weight = weight

    def __lt__(self, other_edge):
        return self.weight < other_edge.weight


class Prim(object):

    def __init__(self, vertex_list):
        self.unvisited_list = vertex_list
        self.edge_heap = []
        self.spanning_tree = []
        self.min_weight = 0

    def make_min_spanning_tree(self, start_vertex):

        current_vertex = start_vertex
        self.unvisited_list.remove(current_vertex)
        
        while self.unvisited_list:
            for edge in current_vertex.adjacency_list:
                if edge.to_vertex in self.unvisited_list:
                    heapq.heappush(self.edge_heap, edge)
            
            min_edge = heapq.heappop(self.edge_heap)

            if min_edge.to_vertex in self.unvisited_list:
                self.spanning_tree.append(min_edge)
                self.min_weight += min_edge.weight
                current_vertex = min_edge.to_vertex
                self.unvisited_list.remove(current_vertex)
    
    def show_min_spanning_tree(self):
        edges = []
        weight = 0        
        for edge in self.spanning_tree:
            edges.append(f'{edge.from_vertex.name}-{edge.to_vertex.name}')
            weight += edge.weight
        print(edges, 'weight:', weight)


if __name__ == '__main__':
    a = Vertex('A')
    b = Vertex('B')
    c = Vertex('C')
    d = Vertex('D')
    e = Vertex('E')
    f = Vertex('F')
    g = Vertex('G')

    edge_ab = Edge(a, b, 7)
    edge_ad = Edge(a, d, 5)
    a.adjacency_list.extend([edge_ab, edge_ad])
    edge_ba = Edge(b, a, 7)
    edge_bc = Edge(b, c, 8)
    edge_bd = Edge(b, d, 9)
    edge_be = Edge(b, e, 7)
    b.adjacency_list.extend([edge_ba, edge_bc, edge_bd, edge_be])
    edge_cb = Edge(c, b, 8)
    edge_ce = Edge(c, e, 5)
    c.adjacency_list.extend([edge_cb, edge_ce])
    edge_da = Edge(d, a, 5)
    edge_db = Edge(d, b, 9)
    edge_de = Edge(d, e, 15)
    edge_df = Edge(d, f, 6)
    d.adjacency_list.extend([edge_da, edge_db, edge_de, edge_df])
    edge_eb = Edge(e, b, 7)
    edge_ec = Edge(e, c, 5)
    edge_ed = Edge(e, d, 15)
    edge_ef = Edge(e, f, 8)
    edge_eg = Edge(e, g, 9)
    e.adjacency_list.extend([edge_eb, edge_ec, edge_ed, edge_ef, edge_eg])
    edge_fd = Edge(f, d, 6)
    edge_fe = Edge(f, e, 8)
    edge_fg = Edge(f, g, 11)
    f.adjacency_list.extend([edge_fd, edge_fe, edge_fg])
    edge_ge = Edge(g, e, 9)
    edge_gf = Edge(g, f, 11)
    g.adjacency_list.extend([edge_ge, edge_gf])

    vertex_list = [a, b, c, d, e, f, g]
    
    prim = Prim(vertex_list)
    prim.make_min_spanning_tree(a)
    # ['A-D', 'D-F', 'A-B', 'B-E', 'E-C', 'E-G'] weight: 39
    prim.show_min_spanning_tree()

想定通りの結果になりました。