2分探索木をPythonで実装します。
二分探索木
二分探索木(にぶんたんさくぎ、英: binary search tree)は、コンピュータプログラムにおいて、「左の子孫の値 ≤ 親の値 ≤ 右の子孫の値」という制約を持つ二分木である。探索木のうちで最も基本的な木構造である。
出典: フリー百科事典『ウィキペディア(Wikipedia)』
探索の計算量は平均的には O(log2 N) と、通常の2分木と比較して減らせる可能性のある構造となります。
上の定義では、「左の子孫の値 ≤ 親の値 ≤ 右の子孫の値」となっていますが、ここでは、 「左の子孫の値 < 親の値 < 右の子孫の値」と考えます。
ノードの定義
ノードを定義します。
通常の2分木と変わりません。
class Node: def __init__(self, data=None): self.data = data self.left = None self.right = None
BSTクラス
2分探索木のクラスをBSTクラスとして定義します。
class BST: def __init__(self): self.root = None
挿入
BSTクラスのメソッドとして、データの挿入を定義します。
アルゴリズムは以下のように考えます。
ルートから手順を開始する。
出典: フリー百科事典『ウィキペディア(Wikipedia)』
着目しているノードと目的の値を比較する。「目的の値 < 着目しているノード」なら左の子、「着目しているノード ≤ 目的の値」なら右の子が、次の着目ノードとなる。
次の着目ノードが存在しなければ(現在の着目ノードが葉であれば)、次の着目ノードの位置にデータを挿入。存在すれば、次の着目ノードに移って繰り返し。
挿入の計算量は木の高さに比例し、平衡状態であれば O(log N) となる。
def insert(self, data): if self.root is None: self.root = Node(data) else: self._insert(data, self.root) def _insert(self, data, cur_node): if data < cur_node.data: if cur_node.left is None: cur_node.left = Node(data) else: self._insert(data, cur_node.left) elif data > cur_node.data: if cur_node.right is None: cur_node.right = Node(data) else: self._insert(data, cur_node.right) else: print("同じ値があります。")
データの列挙
BSTクラスのメソッドとして、全データのプリントを定義します。
アルゴリズムは以下のように考えます。
以下のように 再帰呼び出し を使うことで、二分探索木に登録された全データをソートされた順序で列挙できる。
出典: フリー百科事典『ウィキペディア(Wikipedia)』
左の子をルートとする部分木に対して、この処理を再帰的に適用する。
親を表示する。
右の子をルートとする部分木に対して、この処理を再帰的に適用する。
def inorder_print_tree(self): if self.root: self._inorder_print_tree(self.root) def _inorder_print_tree(self, cur_node): if cur_node: self._inorder_print_tree(cur_node.left) print(str(cur_node.data)) self._inorder_print_tree(cur_node.right)
データの探索
BSTクラスのメソッドとして、あるデータの探索を定義します。
アルゴリズムは以下のように考えます。
ルートから手順を開始する。
出典: フリー百科事典『ウィキペディア(Wikipedia)』
着目しているノードと目的の値を比較する。等しいか、着目ノードが存在しなければ終了。
「目的の値 < 着目しているノード」なら左の子、「着目しているノード < 目的の値」なら右の子へ移って繰り返し。
探索の計算量は木の高さに比例し、平衡状態であれば O(log N) となる。
再帰的に書きます。
def find_rec(self, data): if self.root: is_found = self._find(data, self.root) if is_found: return True return False else: return None def _find(self, data, cur_node): if data > cur_node.data and cur_node.right: return self._find(data, cur_node.right) elif data < cur_node.data and cur_node.left: return self._find(data, cur_node.left) if data == cur_node.data: return True
ループを使って書きます。
def find_by_while(self, data): if self.root is None: return None cur_node = self.root while cur_node: if data == cur_node.data: return True if data < cur_node.data: cur_node = cur_node.left else: cur_node = cur_node.right return False
データの削除
BSTクラスのメソッドとして、あるデータの削除を定義します。
アルゴリズムは以下のように考えます。
1.ルートから手順を開始する。
2.着目しているノードと目的の値を比較する。「目的の値 < 着目しているノード」なら左の子、「着目しているノード ≤ 目的の値」なら右の子が、次に着目するノードとなる。
3.着目ノードが削除する対象(以下、削除ノード)であり、削除ノードが子どもを持たないなら、そのノードをそのまま削除する。
4.削除ノードが子を一つしかもっていない場合は、削除ノードを削除してその子と置き換える。
5.削除ノードが子を二つ持つ場合5-1.削除ノードの右の子から最小の値を探索する。
出典: フリー百科事典『ウィキペディア(Wikipedia)』
5-2.1 で探索してきたノード(以下、探索ノード)を削除対象のノードと置き換えて、削除対象のノードを削除する。このとき 探索ノードの右の子を探索ノードの元位置に置き換える。
以下のコードでは、子ノードが2つある場合は、削除対象ノードのデータのみを置き換えて、その後右のsubtreeから置き換えノードを再帰的に削除することで、実装しています。
def _delete_node(self, data, cur_node): if cur_node is None: return cur_node # 削除するデータが現在のノードより小さい場合は左にある if data < cur_node.data: cur_node.left = self._delete_node(data, cur_node.left) # 削除するデータが現在のノードより大きい場合は右にある elif data > cur_node.data: cur_node.right = self._delete_node(data, cur_node.right) # データを削除する else: # 子ノードが一つまたはなしの場合 if cur_node.left is None: temp = cur_node.right cur_node = None return temp elif cur_node.right is None: temp = cur_node.left cur_node.left = None return temp # 子ノードが2つの場合は右側の一番小さい値 temp = self.find_min(cur_node.right) cur_node.data = temp.data cur_node.right = self._delete_node(temp.data, cur_node.right) return cur_node
2分探索木コードまとめ
class Node: def __init__(self, data=None): self.data = data self.left = None self.right = None class BST: def __init__(self): self.root = None def insert(self, data): if self.root is None: self.root = Node(data) else: self._insert(data, self.root) def _insert(self, data, cur_node): if data < cur_node.data: if cur_node.left is None: cur_node.left = Node(data) else: self._insert(data, cur_node.left) elif data > cur_node.data: if cur_node.right is None: cur_node.right = Node(data) else: self._insert(data, cur_node.right) else: print("同じ値があります。") def inorder_print_tree(self): if self.root: self._inorder_print_tree(self.root) def _inorder_print_tree(self, cur_node): if cur_node: self._inorder_print_tree(cur_node.left) print(str(cur_node.data)) self._inorder_print_tree(cur_node.right) def find_min(self, cur_node=None): if self.root is None: return None while cur_node.left: cur_node = cur_node.left return cur_node def delete_node(self, data): if self.root is None: return False else: self.root = self._delete_node(data, self.root) def _delete_node(self, data, cur_node): if cur_node is None: return cur_node # 削除するデータが現在のノードより小さい場合は左にある if data < cur_node.data: cur_node.left = self._delete_node(data, cur_node.left) # 削除するデータが現在のノードより大きい場合は右にある elif data > cur_node.data: cur_node.right = self._delete_node(data, cur_node.right) # データを削除する else: # 子ノードが一つまたはなしの場合 if cur_node.left is None: temp = cur_node.right cur_node = None return temp elif cur_node.right is None: temp = cur_node.left cur_node.left = None return temp # 子ノードが2つの場合は右側の一番小さい値 temp = self.find_min(cur_node.right) cur_node.data = temp.data cur_node.right = self._delete_node(temp.data, cur_node.right) return cur_node # 再帰 def find_rec(self, data): if self.root: is_found = self._find(data, self.root) if is_found: return True return False else: return None def _find(self, data, cur_node): if data > cur_node.data and cur_node.right: return self._find(data, cur_node.right) elif data < cur_node.data and cur_node.left: return self._find(data, cur_node.left) if data == cur_node.data: return True # ループ def find_by_while(self, data): if self.root is None: return None cur_node = self.root while cur_node: if data == cur_node.data: return True if data < cur_node.data: cur_node = cur_node.left else: cur_node = cur_node.right return False bst = BST() bst.insert(8) bst.insert(3) bst.insert(10) bst.insert(1) bst.insert(6) bst.insert(14) bst.insert(4) bst.insert(7) bst.insert(13) # 8 # /\ # 3 10 # /\ \ # 1 6 14 # /\ / # 4 7 13 bst.inorder_print_tree() # 1 # 3 # 4 # 6 # 7 # 8 # 10 # 13 # 14 # # True print(bst.find_rec(6)) print(bst.find_by_while(6)) # # Flase print(bst.find_rec(12)) print(bst.find_by_while(12)) bst.delete_node(3) # 8 # /\ # 4 10 # /\ \ # 1 6 14 # \ / # 7 13 bst.inorder_print_tree() # 1 # 4 # 6 # 7 # 8 # 10 # 13 # 14