前回からの続きです。
Python では、heapq
によりheapがモジュールとして提供されているので、普段はこちらを使います。
ここでは、学習のため、前回の内容に沿って、Python で heap を実装します。
max heap を考えます。
データ構造
heap のデータ構造ですが、配列はリストで表します。
また、リストの初期状態を0にすることで、heap の index が1から始まるようにします。
これにより、前回のindexが1から始まることと整合性を保ちます。
class Heap(object): def __init__(self): # リストの最初を0にすることで、heap の index が 1 始まるように調整 self.heap_list = [0] self.size = 0 def print_heap(self): print(self.heap_list[1:])
ノードの挿入
ノードを配列の一番最後に挿入し、親の値が自分より小さい場合は、入れ替えを行います。
def get_parent(self, index): if index <= 1: print('2以上の値を入力して下さい。') return parent = self.heap_list[index // 2] return parent def set_parent(self, index, value): self.heap_list[index // 2] = value def insert(self, value): self.heap_list.append(value) self.size += 1 index = self.size while index // 2 > 0: if self.heap_list[index] > self.get_parent(index): # 親の方が小さいときは入れ替えを行う。 tmp = self.get_parent(index) self.set_parent(index, self.heap_list[index]) self.heap_list[index] = tmp else: break index = index // 2
前回の挿入の例に従い、以下に60を挿入します。
if __name__ == "__main__": heap = Heap() heap.insert(50) heap.insert(30) heap.insert(20) heap.insert(15) heap.insert(10) heap.insert(8) heap.insert(16) heap.print_heap() # [50, 30, 20, 15, 10, 8, 16] heap.insert(60) heap.print_heap() # [60, 50, 20, 30, 10, 8, 16, 15]
60の挿入が max heap の条件を満たすようにちゃんと行われました。
ノードの削除
リストの一番上の値を削除し、リストの一番最後の値をリストの一番最初に持ってきます。
その後、左の子と右の子と値を比較し、大きいほうの値の子と自身の値を比較し、子の方が大きい場合は、入れ替えを行います。この入れ替えは、percolate_down
関数として定義しおきます。
def get_left_child(self, index): left_child = self.heap_list[2 * index] return left_child def get_right_child(self, index): right_child = self.heap_list[2 * index + 1] return right_child def get_index_minimum_child(self, index): # 右の子が無い場合 if index * 2 + 1 > self.size: return index * 2 else: if self.get_left_child(index) > self.get_right_child(index): return index * 2 else: return index * 2 + 1 def percolate_down(self, index): while (index * 2) <= self.size: index_minimum_child = self.get_index_minimum_child(index) if self.heap_list[index] < self.heap_list[index_minimum_child]: tmp = self.heap_list[index] self.heap_list[index] = self.heap_list[index_minimum_child] self.heap_list[index_minimum_child] = tmp index = index_minimum_child def delete(self): return_value = self.heap_list[1] self.heap_list[1] = self.heap_list[self.size] self.size -= 1 # リストの最後を削除 self.heap_list.pop() index = 1 self.percolate_down(index) return return_value
前回の例に従い、以下から50を削除します。
if __name__ == "__main__": heap = Heap() heap.insert(50) heap.insert(30) heap.insert(20) heap.insert(15) heap.insert(10) heap.insert(8) heap.insert(16) heap.print_heap() # [50, 30, 20, 15, 10, 8, 16] val = heap.delete() heap.print_heap() # [30, 16, 20, 15, 10, 8] print(val) # 50
50が削除され、また、max heap の条件を満たすよう木が再構成されました。
heapfy
与えられた配列から、heap 構造を満たす配列を作成します。
削除で定義したpercolate_down
を最下層の右側のノードから適用していきます。
ただし、実際のところ、子を持つ一番右型のノードは、 最下層の右側のノード の親であると考えることができるため、len(array) // 2
から percolate_down
を適用します。
def build_heap(self, array): # len(array) // 2 で最初の葉でないノードの場所 index = len(array) // 2 self.size = len(array) self.heap_list = [0] + array while index > 0: self.percolate_down(index) index -= 1
前回の例に従い、以下に heapify を行います。
if __name__ == "__main__": heap = Heap() array = [10, 20, 15, 12, 40, 25, 18] heap.build_heap(array) heap.print_heap() # [40, 20, 25, 12, 10, 15, 18]
正しい heap 構造になりました。
heap sort
heap sort を行います。
heap sort では、与えられた配列にソートした値を入れていくので、delete
と pecolate_down
を少し変更します。
def delete_for_sort(self): return_value = self.heap_list[1] self.heap_list[1] = self.heap_list[self.size] self.size -= 1 # リストの最後の値を削除した値にする self.heap_list[self.size+1] = return_value index = 1 self.percolate_down_for_sort(index, self.size) def percolate_down_for_sort(self, index, last_index): while (index * 2) <= last_index: index_minimum_child = self.get_index_minimum_child(index) if self.heap_list[index] < self.heap_list[index_minimum_child]: tmp = self.heap_list[index] self.heap_list[index] = self.heap_list[index_minimum_child] self.heap_list[index_minimum_child] = tmp index = index_minimum_child def heap_sort(self, array): self.build_heap(array) while heap.size >= 1: heap.delete_for_sort() return self.heap_list[1:]
前回の例に従い、10, 20, 15, 30, 40
を配列として与えます。
if __name__ == "__main__": heap = Heap() array = [10, 20, 15, 30, 40] sorted_array = heap.heap_sort(array) print(sorted_array) # [10, 15, 20, 30, 40]
ちゃんとソートされました。
最後に、今回作成したコードです。
class Heap(object): def __init__(self): # リストの最初を0にすることで、heap の index が 1 始まるように調整 self.heap_list = [0] self.size = 0 def print_heap(self): print(self.heap_list[1:]) def get_parent(self, index): if index <= 1: print('2以上の値を入力して下さい。') return parent = self.heap_list[index // 2] return parent def set_parent(self, index, value): self.heap_list[index // 2] = value def insert(self, value): self.heap_list.append(value) self.size += 1 index = self.size while index // 2 > 0: if self.heap_list[index] > self.get_parent(index): # 親の方が小さいときは入れ替えを行う。 tmp = self.get_parent(index) self.set_parent(index, self.heap_list[index]) self.heap_list[index] = tmp else: break index = index // 2 def get_left_child(self, index): left_child = self.heap_list[2 * index] return left_child def get_right_child(self, index): right_child = self.heap_list[2 * index + 1] return right_child def get_index_minimum_child(self, index): # 右の子が無い場合 if index * 2 + 1 > self.size: return index * 2 else: if self.get_left_child(index) > self.get_right_child(index): return index * 2 else: return index * 2 + 1 def percolate_down(self, index): while (index * 2) <= self.size: index_minimum_child = self.get_index_minimum_child(index) if self.heap_list[index] < self.heap_list[index_minimum_child]: tmp = self.heap_list[index] self.heap_list[index] = self.heap_list[index_minimum_child] self.heap_list[index_minimum_child] = tmp index = index_minimum_child def delete(self): return_value = self.heap_list[1] self.heap_list[1] = self.heap_list[self.size] self.size -= 1 # リストの最後を削除 self.heap_list.pop() index = 1 self.percolate_down(index) return return_value def build_heap(self, array): # len(array) // 2 で最初の葉でないノードの場所 index = len(array) // 2 self.size = len(array) self.heap_list = [0] + array while index > 0: self.percolate_down(index) index -= 1 def delete_for_sort(self): return_value = self.heap_list[1] self.heap_list[1] = self.heap_list[self.size] self.size -= 1 # リストの最後の値を削除した値にする self.heap_list[self.size+1] = return_value index = 1 self.percolate_down_for_sort(index, self.size) def percolate_down_for_sort(self, index, last_index): while (index * 2) <= last_index: index_minimum_child = self.get_index_minimum_child(index) if self.heap_list[index] < self.heap_list[index_minimum_child]: tmp = self.heap_list[index] self.heap_list[index] = self.heap_list[index_minimum_child] self.heap_list[index_minimum_child] = tmp index = index_minimum_child def heap_sort(self, array): self.build_heap(array) while heap.size >= 1: heap.delete_for_sort() return self.heap_list[1:] if __name__ == "__main__": heap = Heap() array = [10, 20, 15, 30, 40] sorted_array = heap.heap_sort(array) print(sorted_array)