2分探索木を Python で実装します。
以前、ほぼ全く同じ内容で記事を書いています。
ノード
2分探索木では、ノードは自身のデータと、0個、1個、2個のいずれかの子ノードを持っており、左側の子ノードは親の値より小さい値、右の側の子ノードは親の値より大きい値を持ちます。
# New added. class Node(object): def __init__(self, data): self.data = data self.left_child = None self.right_child = None
データの挿入
データの挿入は、もしまだ「ルートがない」つまり最初のノードを挿入する場合は、そのノードをルートに設定します。
既にルートが設定されている場合は、データとルートを比較して小さい場合は左、大きい場合は右に進みます。
進んだ方向にノードが無ければ、データをその箇所のノードに設定します。
進んだ方向にノードがあれば、そのノードとデータの大小を比較して、また左右に分かれます。この処理は再帰的に実行します。
挿入箇所を探す際に、このようにデータの大小とノードの大小を比較して左右どちらかの方向に進んでいくことで、木が平衡していれば、\( O (\log N) \) の計算量で挿入が行えます。
class Node(object): def __init__(self, data): self.data = data self.left_child = None self.right_child = None #New added. class BinarySearchTree(object): def __init__(self): self.root = None def insert(self, data): if not self.root: self.root = Node(data) else: self._insert(data, self.root) # 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N) def _insert(self, data, node): if data < node.data: if node.left_child: self._insert(data, node.left_child) else: node.left_child = Node(data) else: if node.right_child: self._insert(data, node.right_child) else: node.right_child = Node(data)
最小値・最大値の探索
2分探索木の中で一番左側にあるノードが最小値となります。
左側にノードがある場合は再帰的に最小値の探索を行い、左側にノードがない場合はそのノードが一番左側にあるので、最小値になります。
木が平衡していれば、\( O (\log N) \) の計算量になります。
同様に、 2分探索木の中で一番右側にあるノードが最大値となり、再帰的に処理します。
class Node(object): def __init__(self, data): self.data = data self.left_child = None self.right_child = None class BinarySearchTree(object): def __init__(self): self.root = None def insert(self, data): if not self.root: self.root = Node(data) else: self._insert(data, self.root) # 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N) def _insert(self, data, node): if data < node.data: if node.left_child: self._insert(data, node.left_child) else: node.left_child = Node(data) else: if node.right_child: self._insert(data, node.right_child) else: node.right_child = Node(data) # New Added. def get_min(self): if self.root: return self._get_min(self.root) return False def _get_min(self, node): if node.left_child: return self._get_min(node.left_child) return node.data def get_max(self): if self.root: return self._get_max(self.root) return False def _get_max(self, node): if node.right_child: return self._get_max(node.right_child) return node.data
in-order traversal
in-order traversal は、「左側の部分木」→「ルート」→「右側の部分木」という順序で再帰的に木を巡ります。
class Node(object): def __init__(self, data): self.data = data self.left_child = None self.right_child = None class BinarySearchTree(object): def __init__(self): self.root = None def insert(self, data): if not self.root: self.root = Node(data) else: self._insert(data, self.root) # 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N) def _insert(self, data, node): if data < node.data: if node.left_child: self._insert(data, node.left_child) else: node.left_child = Node(data) else: if node.right_child: self._insert(data, node.right_child) else: node.right_child = Node(data) def get_min(self): if self.root: return self._get_min(self.root) return False def _get_min(self, node): if node.left_child: return self._get_min(node.left_child) return node.data def get_max(self): if self.root: return self._get_max(self.root) return False def _get_max(self, node): if node.right_child: return self._get_max(node.right_child) return node.data # New Added. def traverse_inorder(self): if self.root: self._traverse_inorder(self.root) else: print('2分木は空です。') def _traverse_inorder(self, node): if node.left_child: self._traverse_inorder(node.left_child) print(node.data) if node.right_child: self._traverse_inorder(node.right_child)
2分探索木を走査できるようになったので、 以下のような木になるようデータを挿入して、木の走査、最大値、最小値を求めてみます。
class Node(object): def __init__(self, data): self.data = data self.left_child = None self.right_child = None class BinarySearchTree(object): def __init__(self): self.root = None def insert(self, data): if not self.root: self.root = Node(data) else: self._insert(data, self.root) # 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N) def _insert(self, data, node): if data < node.data: if node.left_child: self._insert(data, node.left_child) else: node.left_child = Node(data) else: if node.right_child: self._insert(data, node.right_child) else: node.right_child = Node(data) def get_min(self): if self.root: return self._get_min(self.root) return False def _get_min(self, node): if node.left_child: return self._get_min(node.left_child) return node.data def get_max(self): if self.root: return self._get_max(self.root) return False def _get_max(self, node): if node.right_child: return self._get_max(node.right_child) return node.data def traverse_inorder(self): if self.root: self._traverse_inorder(self.root) # 改行対策 print() else: print('2分木は空です。') def _traverse_inorder(self, node): if node.left_child: self._traverse_inorder(node.left_child) print(node.data, end=' ') if node.right_child: self._traverse_inorder(node.right_child) if __name__ == '__main__': bst = BinarySearchTree() # 2分木は空です。 bst.traverse_inorder() bst.insert(8) bst.insert(3) bst.insert(10) bst.insert(1) bst.insert(6) bst.insert(14) bst.insert(4) bst.insert(7) bst.insert(13) # ソート順 1 3 4 6 7 8 10 13 14 bst.traverse_inorder() # MIN 1 print('MIN', bst.get_min()) # MAX 14 print('MAX', bst.get_max())
ノードの削除
場合分けが多いのと、ポインタを直接扱うことができないので、コードが少しややこしくなります。
削除を行うヘルパー関数 _remove
は、親ノードの子ノードへのポインタを再設定するために、子ノードを返しています。
平衡木であれば、計算量は\( O (\log N) \) です。
class Node(object): def __init__(self, data): self.data = data self.left_child = None self.right_child = None class BinarySearchTree(object): def __init__(self): self.root = None def insert(self, data): if not self.root: self.root = Node(data) else: self._insert(data, self.root) # 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N) def _insert(self, data, node): if data < node.data: if node.left_child: self._insert(data, node.left_child) else: node.left_child = Node(data) else: if node.right_child: self._insert(data, node.right_child) else: node.right_child = Node(data) def get_min(self): if self.root: return self._get_min(self.root) return False def _get_min(self, node): if node.left_child: return self._get_min(node.left_child) return node.data def get_max(self): if self.root: return self._get_max(self.root) return False def _get_max(self, node): if node.right_child: return self._get_max(node.right_child) return node.data def traverse_inorder(self): if self.root: self._traverse_inorder(self.root) # 改行対策 print() else: print('2分木は空です。') def _traverse_inorder(self, node): if node.left_child: self._traverse_inorder(node.left_child) print(node.data, end=' ') if node.right_child: self._traverse_inorder(node.right_child) # new added. def remove(self, data): if self.root: self.root = self._remove(data, self.root) return False def _remove(self, data, node): if not node: return node if data < node.data: node.left_child = self._remove(data, node.left_child) elif data > node.data: node.right_child = self._remove(data, node.right_child) else: # 削除ノードが子を持たない場合は、ノードを削除する。 if not node.right_child and not node.left_child: del node # None を返すことで、親ノードの削除子ノードへのポインタを None に変更 return None # 削除ノードが左の子だけを持つ場合、ノードを削除し、左の子のノードを返す if not node.right_child: temp = node.left_child del node # 左の子のノードを返すことで、親ノードの削除子ノードへのポインタを新しい子ノードに変更 return temp # 削除ノードが右の子だけを持つ場合、左の子だけの場合の逆の操作を行う if not node.left_child: temp = node.right_child del node return temp # 削除ノードが左右の子を持つ場合、ここでは、左側のsubtreeの最大のノードを代わりのノードにすることにする。 # subtreeの最大のノードを取得するヘルパー関数 def _get_max_node(node): if node.right_child: return _get_max_node(node.right_child) return node temp = _get_max_node(node.left_child) node.data = temp.data # 左側のsubtreeから削除ノードと入れ替えたノードを削除 node.left_child = self._remove(temp.data, node.left_child) return node
子のない要素の削除として、「1」と「13」を削除してみます。
if __name__ == '__main__': bst = BinarySearchTree() bst.insert(8) bst.insert(3) bst.insert(10) bst.insert(1) bst.insert(6) bst.insert(14) bst.insert(4) bst.insert(7) bst.insert(13) # 1 3 4 6 7 8 10 13 14 bst.traverse_inorder() # 子のない要素を削除 bst.remove(1) # 3 4 6 7 8 10 13 14 bst.traverse_inorder() bst.remove(13) # 3 4 6 7 8 10 14 bst.traverse_inorder()
子を一つだけ持つ要素の削除として、「10」「14」を削除してみます。
if __name__ == '__main__': bst = BinarySearchTree() bst.insert(8) bst.insert(3) bst.insert(10) bst.insert(1) bst.insert(6) bst.insert(14) bst.insert(4) bst.insert(7) bst.insert(13) # 1 3 4 6 7 8 10 13 14 bst.traverse_inorder() # 子を一つ持つ要素を削除 bst.remove(10) # 1 3 4 6 7 8 13 14 bst.traverse_inorder() bst.remove(14) # 1 3 4 6 7 8 13 bst.traverse_inorder()
子を2つ持つ要素としてルートを削除してみます。
if __name__ == '__main__': bst = BinarySearchTree() bst.insert(8) bst.insert(3) bst.insert(10) bst.insert(1) bst.insert(6) bst.insert(14) bst.insert(4) bst.insert(7) bst.insert(13) # 1 3 4 6 7 8 10 13 14 bst.traverse_inorder() # ルートを削除 bst.remove(8) # 1 3 4 6 7 10 13 14 bst.traverse_inorder()
それぞれ想定通りに動作しています。