[Python] 2分ヒープ (2)

前回からの続きです。

2分ヒープ 2分ヒープを理解して、Pythonで実装します。 まず、heap とは、「積み重なった山のようなもの」です。a heap of stones と言えば、石が山のように積み重なっているものです。 ヒープソートは、この山のように積み重なったものの一番上...

Python では、heapq によりheapがモジュールとして提供されているので、普段はこちらを使います。

heapq — ヒープキューアルゴリズム

ここでは、学習のため、前回の内容に沿って、Python で heap を実装します。

max heap を考えます。

データ構造

heap のデータ構造ですが、配列はリストで表します。

また、リストの初期状態を0にすることで、heap の index が1から始まるようにします。

これにより、前回のindexが1から始まることと整合性を保ちます。

class Heap(object):
    def __init__(self):
        # リストの最初を0にすることで、heap の index が 1 始まるように調整
        self.heap_list = [0]
        self.size = 0

    def print_heap(self):
        print(self.heap_list[1:])

ノードの挿入

ノードを配列の一番最後に挿入し、親の値が自分より小さい場合は、入れ替えを行います。

def get_parent(self, index):
        if index <= 1:
            print('2以上の値を入力して下さい。')
            return
        parent = self.heap_list[index // 2]
        return parent
    
    def set_parent(self, index, value):
        self.heap_list[index // 2] = value

    def insert(self, value):
        self.heap_list.append(value)
        self.size += 1
        index = self.size
        while index // 2 > 0:
            if self.heap_list[index] > self.get_parent(index):
                # 親の方が小さいときは入れ替えを行う。
                tmp = self.get_parent(index)
                self.set_parent(index, self.heap_list[index])
                self.heap_list[index] = tmp
            else:
                break
            index = index // 2

前回の挿入の例に従い、以下に60を挿入します。

if __name__ == "__main__":
    
    heap = Heap()
    heap.insert(50)
    heap.insert(30)
    heap.insert(20)
    heap.insert(15)
    heap.insert(10)
    heap.insert(8)
    heap.insert(16)
    heap.print_heap()
     # [50, 30, 20, 15, 10, 8, 16]
    heap.insert(60)
    heap.print_heap()
    # [60, 50, 20, 30, 10, 8, 16, 15]

60の挿入が max heap の条件を満たすようにちゃんと行われました。

ノードの削除

リストの一番上の値を削除し、リストの一番最後の値をリストの一番最初に持ってきます。

その後、左の子と右の子と値を比較し、大きいほうの値の子と自身の値を比較し、子の方が大きい場合は、入れ替えを行います。この入れ替えは、percolate_down 関数として定義しおきます。

    def get_left_child(self, index):
        left_child = self.heap_list[2 * index]
        return left_child
    
    def get_right_child(self, index):
        right_child = self.heap_list[2 * index + 1]
        return right_child

    def get_index_minimum_child(self, index):
        # 右の子が無い場合
        if index * 2 + 1 > self.size:
            return index * 2
        else:
            if self.get_left_child(index) > self.get_right_child(index):
                return index * 2
            else:
                return index * 2 + 1
    
    def percolate_down(self, index):
        while (index * 2) <= self.size:
            index_minimum_child = self.get_index_minimum_child(index)
            if self.heap_list[index] < self.heap_list[index_minimum_child]:
                tmp = self.heap_list[index]
                self.heap_list[index] = self.heap_list[index_minimum_child]
                self.heap_list[index_minimum_child] = tmp 
            index = index_minimum_child

    def delete(self):
        return_value = self.heap_list[1]
        self.heap_list[1] = self.heap_list[self.size]
        self.size -= 1
        # リストの最後を削除
        self.heap_list.pop()
        index = 1
        self.percolate_down(index)
        return return_value

前回の例に従い、以下から50を削除します。

if __name__ == "__main__":

    heap = Heap()
    heap.insert(50)
    heap.insert(30)
    heap.insert(20)
    heap.insert(15)
    heap.insert(10)
    heap.insert(8)
    heap.insert(16)
    heap.print_heap()
    # [50, 30, 20, 15, 10, 8, 16]
    val = heap.delete()
    heap.print_heap()
    # [30, 16, 20, 15, 10, 8]
    print(val)
    # 50

50が削除され、また、max heap の条件を満たすよう木が再構成されました。

heapfy

与えられた配列から、heap 構造を満たす配列を作成します。

削除で定義したpercolate_down を最下層の右側のノードから適用していきます。

ただし、実際のところ、子を持つ一番右型のノードは、 最下層の右側のノード の親であると考えることができるため、len(array) // 2 から percolate_down を適用します。

    def build_heap(self, array):
        # len(array) // 2 で最初の葉でないノードの場所 
        index = len(array) // 2
        self.size = len(array)
        self.heap_list = [0] + array
        while index > 0:
            self.percolate_down(index)
            index -= 1

前回の例に従い、以下に heapify を行います。

if __name__ == "__main__":
  
    heap = Heap()
    array = [10, 20, 15, 12, 40, 25, 18]
    heap.build_heap(array)
    heap.print_heap()
    # [40, 20, 25, 12, 10, 15, 18]

正しい heap 構造になりました。

heap sort

heap sort を行います。

heap sort では、与えられた配列にソートした値を入れていくので、deletepecolate_down を少し変更します。

    def delete_for_sort(self):
        return_value = self.heap_list[1]
        self.heap_list[1] = self.heap_list[self.size]
        self.size -= 1
        # リストの最後の値を削除した値にする
        self.heap_list[self.size+1] = return_value
        index = 1
        self.percolate_down_for_sort(index, self.size)

    def percolate_down_for_sort(self, index, last_index):
        while (index * 2) <= last_index:
            index_minimum_child = self.get_index_minimum_child(index)
            if self.heap_list[index] < self.heap_list[index_minimum_child]:
                tmp = self.heap_list[index]
                self.heap_list[index] = self.heap_list[index_minimum_child]
                self.heap_list[index_minimum_child] = tmp 
            index = index_minimum_child
            
    def heap_sort(self, array):
        self.build_heap(array)
        while heap.size >= 1:
            heap.delete_for_sort()
        return self.heap_list[1:]

前回の例に従い、10, 20, 15, 30, 40 を配列として与えます。

if __name__ == "__main__":

    heap = Heap()
    array = [10, 20, 15, 30, 40]
    sorted_array = heap.heap_sort(array)
    print(sorted_array)
    # [10, 15, 20, 30, 40]

ちゃんとソートされました。

最後に、今回作成したコードです。

class Heap(object):
    def __init__(self):
        # リストの最初を0にすることで、heap の index が 1 始まるように調整
        self.heap_list = [0]
        self.size = 0

    def print_heap(self):
        print(self.heap_list[1:])

    def get_parent(self, index):
        if index <= 1:
            print('2以上の値を入力して下さい。')
            return
        parent = self.heap_list[index // 2]
        return parent
    
    def set_parent(self, index, value):
        self.heap_list[index // 2] = value

    def insert(self, value):
        self.heap_list.append(value)
        self.size += 1
        index = self.size
        while index // 2 > 0:
            if self.heap_list[index] > self.get_parent(index):
                # 親の方が小さいときは入れ替えを行う。
                tmp = self.get_parent(index)
                self.set_parent(index, self.heap_list[index])
                self.heap_list[index] = tmp
            else:
                break
            index = index // 2

    def get_left_child(self, index):
        left_child = self.heap_list[2 * index]
        return left_child
    
    def get_right_child(self, index):
        right_child = self.heap_list[2 * index + 1]
        return right_child

    def get_index_minimum_child(self, index):
        # 右の子が無い場合
        if index * 2 + 1 > self.size:
            return index * 2
        else:
            if self.get_left_child(index) > self.get_right_child(index):
                return index * 2
            else:
                return index * 2 + 1
    
    def percolate_down(self, index):
        while (index * 2) <= self.size:
            index_minimum_child = self.get_index_minimum_child(index)
            if self.heap_list[index] < self.heap_list[index_minimum_child]:
                tmp = self.heap_list[index]
                self.heap_list[index] = self.heap_list[index_minimum_child]
                self.heap_list[index_minimum_child] = tmp 
            index = index_minimum_child

    def delete(self):
        return_value = self.heap_list[1]
        self.heap_list[1] = self.heap_list[self.size]
        self.size -= 1
        # リストの最後を削除
        self.heap_list.pop()
        index = 1
        self.percolate_down(index)
        return return_value
    
    def build_heap(self, array):
        # len(array) // 2 で最初の葉でないノードの場所 
        index = len(array) // 2
        self.size = len(array)
        self.heap_list = [0] + array
        while index > 0:
            self.percolate_down(index)
            index -= 1

    def delete_for_sort(self):
        return_value = self.heap_list[1]
        self.heap_list[1] = self.heap_list[self.size]
        self.size -= 1
        # リストの最後の値を削除した値にする
        self.heap_list[self.size+1] = return_value
        index = 1
        self.percolate_down_for_sort(index, self.size)

    def percolate_down_for_sort(self, index, last_index):
        while (index * 2) <= last_index:
            index_minimum_child = self.get_index_minimum_child(index)
            if self.heap_list[index] < self.heap_list[index_minimum_child]:
                tmp = self.heap_list[index]
                self.heap_list[index] = self.heap_list[index_minimum_child]
                self.heap_list[index_minimum_child] = tmp 
            index = index_minimum_child
            
    def heap_sort(self, array):
        self.build_heap(array)
        while heap.size >= 1:
            heap.delete_for_sort()
        return self.heap_list[1:]
            

if __name__ == "__main__":

    heap = Heap()
    array = [10, 20, 15, 30, 40]
    sorted_array = heap.heap_sort(array)
    print(sorted_array)