2分探索木を Python で実装します。
以前、ほぼ全く同じ内容で記事を書いています。
ノード
2分探索木では、ノードは自身のデータと、0個、1個、2個のいずれかの子ノードを持っており、左側の子ノードは親の値より小さい値、右の側の子ノードは親の値より大きい値を持ちます。
# New added.
class Node(object):
def __init__(self, data):
self.data = data
self.left_child = None
self.right_child = None
データの挿入
データの挿入は、もしまだ「ルートがない」つまり最初のノードを挿入する場合は、そのノードをルートに設定します。
既にルートが設定されている場合は、データとルートを比較して小さい場合は左、大きい場合は右に進みます。
進んだ方向にノードが無ければ、データをその箇所のノードに設定します。
進んだ方向にノードがあれば、そのノードとデータの大小を比較して、また左右に分かれます。この処理は再帰的に実行します。
挿入箇所を探す際に、このようにデータの大小とノードの大小を比較して左右どちらかの方向に進んでいくことで、木が平衡していれば、\( O (\log N) \) の計算量で挿入が行えます。
class Node(object):
def __init__(self, data):
self.data = data
self.left_child = None
self.right_child = None
#New added.
class BinarySearchTree(object):
def __init__(self):
self.root = None
def insert(self, data):
if not self.root:
self.root = Node(data)
else:
self._insert(data, self.root)
# 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N)
def _insert(self, data, node):
if data < node.data:
if node.left_child:
self._insert(data, node.left_child)
else:
node.left_child = Node(data)
else:
if node.right_child:
self._insert(data, node.right_child)
else:
node.right_child = Node(data)
最小値・最大値の探索
2分探索木の中で一番左側にあるノードが最小値となります。
左側にノードがある場合は再帰的に最小値の探索を行い、左側にノードがない場合はそのノードが一番左側にあるので、最小値になります。
木が平衡していれば、\( O (\log N) \) の計算量になります。
同様に、 2分探索木の中で一番右側にあるノードが最大値となり、再帰的に処理します。
class Node(object):
def __init__(self, data):
self.data = data
self.left_child = None
self.right_child = None
class BinarySearchTree(object):
def __init__(self):
self.root = None
def insert(self, data):
if not self.root:
self.root = Node(data)
else:
self._insert(data, self.root)
# 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N)
def _insert(self, data, node):
if data < node.data:
if node.left_child:
self._insert(data, node.left_child)
else:
node.left_child = Node(data)
else:
if node.right_child:
self._insert(data, node.right_child)
else:
node.right_child = Node(data)
# New Added.
def get_min(self):
if self.root:
return self._get_min(self.root)
return False
def _get_min(self, node):
if node.left_child:
return self._get_min(node.left_child)
return node.data
def get_max(self):
if self.root:
return self._get_max(self.root)
return False
def _get_max(self, node):
if node.right_child:
return self._get_max(node.right_child)
return node.data
in-order traversal
in-order traversal は、「左側の部分木」→「ルート」→「右側の部分木」という順序で再帰的に木を巡ります。
class Node(object):
def __init__(self, data):
self.data = data
self.left_child = None
self.right_child = None
class BinarySearchTree(object):
def __init__(self):
self.root = None
def insert(self, data):
if not self.root:
self.root = Node(data)
else:
self._insert(data, self.root)
# 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N)
def _insert(self, data, node):
if data < node.data:
if node.left_child:
self._insert(data, node.left_child)
else:
node.left_child = Node(data)
else:
if node.right_child:
self._insert(data, node.right_child)
else:
node.right_child = Node(data)
def get_min(self):
if self.root:
return self._get_min(self.root)
return False
def _get_min(self, node):
if node.left_child:
return self._get_min(node.left_child)
return node.data
def get_max(self):
if self.root:
return self._get_max(self.root)
return False
def _get_max(self, node):
if node.right_child:
return self._get_max(node.right_child)
return node.data
# New Added.
def traverse_inorder(self):
if self.root:
self._traverse_inorder(self.root)
else:
print('2分木は空です。')
def _traverse_inorder(self, node):
if node.left_child:
self._traverse_inorder(node.left_child)
print(node.data)
if node.right_child:
self._traverse_inorder(node.right_child)
2分探索木を走査できるようになったので、 以下のような木になるようデータを挿入して、木の走査、最大値、最小値を求めてみます。
class Node(object):
def __init__(self, data):
self.data = data
self.left_child = None
self.right_child = None
class BinarySearchTree(object):
def __init__(self):
self.root = None
def insert(self, data):
if not self.root:
self.root = Node(data)
else:
self._insert(data, self.root)
# 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N)
def _insert(self, data, node):
if data < node.data:
if node.left_child:
self._insert(data, node.left_child)
else:
node.left_child = Node(data)
else:
if node.right_child:
self._insert(data, node.right_child)
else:
node.right_child = Node(data)
def get_min(self):
if self.root:
return self._get_min(self.root)
return False
def _get_min(self, node):
if node.left_child:
return self._get_min(node.left_child)
return node.data
def get_max(self):
if self.root:
return self._get_max(self.root)
return False
def _get_max(self, node):
if node.right_child:
return self._get_max(node.right_child)
return node.data
def traverse_inorder(self):
if self.root:
self._traverse_inorder(self.root)
# 改行対策
print()
else:
print('2分木は空です。')
def _traverse_inorder(self, node):
if node.left_child:
self._traverse_inorder(node.left_child)
print(node.data, end=' ')
if node.right_child:
self._traverse_inorder(node.right_child)
if __name__ == '__main__':
bst = BinarySearchTree()
# 2分木は空です。
bst.traverse_inorder()
bst.insert(8)
bst.insert(3)
bst.insert(10)
bst.insert(1)
bst.insert(6)
bst.insert(14)
bst.insert(4)
bst.insert(7)
bst.insert(13)
# ソート順 1 3 4 6 7 8 10 13 14
bst.traverse_inorder()
# MIN 1
print('MIN', bst.get_min())
# MAX 14
print('MAX', bst.get_max())
ノードの削除
場合分けが多いのと、ポインタを直接扱うことができないので、コードが少しややこしくなります。
削除を行うヘルパー関数 _remove は、親ノードの子ノードへのポインタを再設定するために、子ノードを返しています。
平衡木であれば、計算量は\( O (\log N) \) です。
class Node(object):
def __init__(self, data):
self.data = data
self.left_child = None
self.right_child = None
class BinarySearchTree(object):
def __init__(self):
self.root = None
def insert(self, data):
if not self.root:
self.root = Node(data)
else:
self._insert(data, self.root)
# 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N)
def _insert(self, data, node):
if data < node.data:
if node.left_child:
self._insert(data, node.left_child)
else:
node.left_child = Node(data)
else:
if node.right_child:
self._insert(data, node.right_child)
else:
node.right_child = Node(data)
def get_min(self):
if self.root:
return self._get_min(self.root)
return False
def _get_min(self, node):
if node.left_child:
return self._get_min(node.left_child)
return node.data
def get_max(self):
if self.root:
return self._get_max(self.root)
return False
def _get_max(self, node):
if node.right_child:
return self._get_max(node.right_child)
return node.data
def traverse_inorder(self):
if self.root:
self._traverse_inorder(self.root)
# 改行対策
print()
else:
print('2分木は空です。')
def _traverse_inorder(self, node):
if node.left_child:
self._traverse_inorder(node.left_child)
print(node.data, end=' ')
if node.right_child:
self._traverse_inorder(node.right_child)
# new added.
def remove(self, data):
if self.root:
self.root = self._remove(data, self.root)
return False
def _remove(self, data, node):
if not node:
return node
if data < node.data:
node.left_child = self._remove(data, node.left_child)
elif data > node.data:
node.right_child = self._remove(data, node.right_child)
else:
# 削除ノードが子を持たない場合は、ノードを削除する。
if not node.right_child and not node.left_child:
del node
# None を返すことで、親ノードの削除子ノードへのポインタを None に変更
return None
# 削除ノードが左の子だけを持つ場合、ノードを削除し、左の子のノードを返す
if not node.right_child:
temp = node.left_child
del node
# 左の子のノードを返すことで、親ノードの削除子ノードへのポインタを新しい子ノードに変更
return temp
# 削除ノードが右の子だけを持つ場合、左の子だけの場合の逆の操作を行う
if not node.left_child:
temp = node.right_child
del node
return temp
# 削除ノードが左右の子を持つ場合、ここでは、左側のsubtreeの最大のノードを代わりのノードにすることにする。
# subtreeの最大のノードを取得するヘルパー関数
def _get_max_node(node):
if node.right_child:
return _get_max_node(node.right_child)
return node
temp = _get_max_node(node.left_child)
node.data = temp.data
# 左側のsubtreeから削除ノードと入れ替えたノードを削除
node.left_child = self._remove(temp.data, node.left_child)
return node
子のない要素の削除として、「1」と「13」を削除してみます。
if __name__ == '__main__':
bst = BinarySearchTree()
bst.insert(8)
bst.insert(3)
bst.insert(10)
bst.insert(1)
bst.insert(6)
bst.insert(14)
bst.insert(4)
bst.insert(7)
bst.insert(13)
# 1 3 4 6 7 8 10 13 14
bst.traverse_inorder()
# 子のない要素を削除
bst.remove(1)
# 3 4 6 7 8 10 13 14
bst.traverse_inorder()
bst.remove(13)
# 3 4 6 7 8 10 14
bst.traverse_inorder()
子を一つだけ持つ要素の削除として、「10」「14」を削除してみます。
if __name__ == '__main__':
bst = BinarySearchTree()
bst.insert(8)
bst.insert(3)
bst.insert(10)
bst.insert(1)
bst.insert(6)
bst.insert(14)
bst.insert(4)
bst.insert(7)
bst.insert(13)
# 1 3 4 6 7 8 10 13 14
bst.traverse_inorder()
# 子を一つ持つ要素を削除
bst.remove(10)
# 1 3 4 6 7 8 13 14
bst.traverse_inorder()
bst.remove(14)
# 1 3 4 6 7 8 13
bst.traverse_inorder()
子を2つ持つ要素としてルートを削除してみます。
if __name__ == '__main__':
bst = BinarySearchTree()
bst.insert(8)
bst.insert(3)
bst.insert(10)
bst.insert(1)
bst.insert(6)
bst.insert(14)
bst.insert(4)
bst.insert(7)
bst.insert(13)
# 1 3 4 6 7 8 10 13 14
bst.traverse_inorder()
# ルートを削除
bst.remove(8)
# 1 3 4 6 7 10 13 14
bst.traverse_inorder()
それぞれ想定通りに動作しています。