2分探索木をPythonで実装します。
二分探索木
二分探索木(にぶんたんさくぎ、英: binary search tree)は、コンピュータプログラムにおいて、「左の子孫の値 ≤ 親の値 ≤ 右の子孫の値」という制約を持つ二分木である。探索木のうちで最も基本的な木構造である。
出典: フリー百科事典『ウィキペディア(Wikipedia)』
探索の計算量は平均的には O(log2 N) と、通常の2分木と比較して減らせる可能性のある構造となります。
上の定義では、「左の子孫の値 ≤ 親の値 ≤ 右の子孫の値」となっていますが、ここでは、 「左の子孫の値 < 親の値 < 右の子孫の値」と考えます。
ノードの定義
ノードを定義します。
通常の2分木と変わりません。
def __init__(self, data=None):
class Node:
def __init__(self, data=None):
self.data = data
self.left = None
self.right = None
class Node:
def __init__(self, data=None):
self.data = data
self.left = None
self.right = None
BSTクラス
2分探索木のクラスをBSTクラスとして定義します。
class BST:
def __init__(self):
self.root = None
class BST:
def __init__(self):
self.root = None
挿入
BSTクラスのメソッドとして、データの挿入を定義します。
アルゴリズムは以下のように考えます。
ルートから手順を開始する。
着目しているノードと目的の値を比較する。「目的の値 < 着目しているノード」なら左の子、「着目しているノード ≤ 目的の値」なら右の子が、次の着目ノードとなる。
次の着目ノードが存在しなければ(現在の着目ノードが葉であれば)、次の着目ノードの位置にデータを挿入。存在すれば、次の着目ノードに移って繰り返し。
挿入の計算量は木の高さに比例し、平衡状態であれば O(log N) となる。
出典: フリー百科事典『ウィキペディア(Wikipedia)』
self._insert(data, self.root)
def _insert(self, data, cur_node):
if cur_node.left is None:
cur_node.left = Node(data)
self._insert(data, cur_node.left)
elif data > cur_node.data:
if cur_node.right is None:
cur_node.right = Node(data)
self._insert(data, cur_node.right)
def insert(self, data):
if self.root is None:
self.root = Node(data)
else:
self._insert(data, self.root)
def _insert(self, data, cur_node):
if data < cur_node.data:
if cur_node.left is None:
cur_node.left = Node(data)
else:
self._insert(data, cur_node.left)
elif data > cur_node.data:
if cur_node.right is None:
cur_node.right = Node(data)
else:
self._insert(data, cur_node.right)
else:
print("同じ値があります。")
def insert(self, data):
if self.root is None:
self.root = Node(data)
else:
self._insert(data, self.root)
def _insert(self, data, cur_node):
if data < cur_node.data:
if cur_node.left is None:
cur_node.left = Node(data)
else:
self._insert(data, cur_node.left)
elif data > cur_node.data:
if cur_node.right is None:
cur_node.right = Node(data)
else:
self._insert(data, cur_node.right)
else:
print("同じ値があります。")
データの列挙
BSTクラスのメソッドとして、全データのプリントを定義します。
アルゴリズムは以下のように考えます。
以下のように 再帰呼び出し を使うことで、二分探索木に登録された全データをソートされた順序で列挙できる。
左の子をルートとする部分木に対して、この処理を再帰的に適用する。
親を表示する。
右の子をルートとする部分木に対して、この処理を再帰的に適用する。
出典: フリー百科事典『ウィキペディア(Wikipedia)』
def inorder_print_tree(self):
self._inorder_print_tree(self.root)
def _inorder_print_tree(self, cur_node):
self._inorder_print_tree(cur_node.left)
print(str(cur_node.data))
self._inorder_print_tree(cur_node.right)
def inorder_print_tree(self):
if self.root:
self._inorder_print_tree(self.root)
def _inorder_print_tree(self, cur_node):
if cur_node:
self._inorder_print_tree(cur_node.left)
print(str(cur_node.data))
self._inorder_print_tree(cur_node.right)
def inorder_print_tree(self):
if self.root:
self._inorder_print_tree(self.root)
def _inorder_print_tree(self, cur_node):
if cur_node:
self._inorder_print_tree(cur_node.left)
print(str(cur_node.data))
self._inorder_print_tree(cur_node.right)
データの探索
BSTクラスのメソッドとして、あるデータの探索を定義します。
アルゴリズムは以下のように考えます。
ルートから手順を開始する。
着目しているノードと目的の値を比較する。等しいか、着目ノードが存在しなければ終了。
「目的の値 < 着目しているノード」なら左の子、「着目しているノード < 目的の値」なら右の子へ移って繰り返し。
探索の計算量は木の高さに比例し、平衡状態であれば O(log N) となる。
出典: フリー百科事典『ウィキペディア(Wikipedia)』
再帰的に書きます。
def find_rec(self, data):
is_found = self._find(data, self.root)
def _find(self, data, cur_node):
if data > cur_node.data and cur_node.right:
return self._find(data, cur_node.right)
elif data < cur_node.data and cur_node.left:
return self._find(data, cur_node.left)
if data == cur_node.data:
def find_rec(self, data):
if self.root:
is_found = self._find(data, self.root)
if is_found:
return True
return False
else:
return None
def _find(self, data, cur_node):
if data > cur_node.data and cur_node.right:
return self._find(data, cur_node.right)
elif data < cur_node.data and cur_node.left:
return self._find(data, cur_node.left)
if data == cur_node.data:
return True
def find_rec(self, data):
if self.root:
is_found = self._find(data, self.root)
if is_found:
return True
return False
else:
return None
def _find(self, data, cur_node):
if data > cur_node.data and cur_node.right:
return self._find(data, cur_node.right)
elif data < cur_node.data and cur_node.left:
return self._find(data, cur_node.left)
if data == cur_node.data:
return True
ループを使って書きます。
def find_by_while(self, data):
if data == cur_node.data:
cur_node = cur_node.right
def find_by_while(self, data):
if self.root is None:
return None
cur_node = self.root
while cur_node:
if data == cur_node.data:
return True
if data < cur_node.data:
cur_node = cur_node.left
else:
cur_node = cur_node.right
return False
def find_by_while(self, data):
if self.root is None:
return None
cur_node = self.root
while cur_node:
if data == cur_node.data:
return True
if data < cur_node.data:
cur_node = cur_node.left
else:
cur_node = cur_node.right
return False
データの削除
BSTクラスのメソッドとして、あるデータの削除を定義します。
アルゴリズムは以下のように考えます。
1.ルートから手順を開始する。
2.着目しているノードと目的の値を比較する。「目的の値 < 着目しているノード」なら左の子、「着目しているノード ≤ 目的の値」なら右の子が、次に着目するノードとなる。
3.着目ノードが削除する対象(以下、削除ノード)であり、削除ノードが子どもを持たないなら、そのノードをそのまま削除する。
4.削除ノードが子を一つしかもっていない場合は、削除ノードを削除してその子と置き換える。
5.削除ノードが子を二つ持つ場合
5-1.削除ノードの右の子から最小の値を探索する。
5-2.1 で探索してきたノード(以下、探索ノード)を削除対象のノードと置き換えて、削除対象のノードを削除する。このとき 探索ノードの右の子を探索ノードの元位置に置き換える。
出典: フリー百科事典『ウィキペディア(Wikipedia)』
以下のコードでは、子ノードが2つある場合は、削除対象ノードのデータのみを置き換えて、その後右のsubtreeから置き換えノードを再帰的に削除することで、実装しています。
def _delete_node(self, data, cur_node):
# 削除するデータが現在のノードより小さい場合は左にある
cur_node.left = self._delete_node(data, cur_node.left)
# 削除するデータが現在のノードより大きい場合は右にある
elif data > cur_node.data:
cur_node.right = self._delete_node(data, cur_node.right)
if cur_node.left is None:
elif cur_node.right is None:
temp = self.find_min(cur_node.right)
cur_node.data = temp.data
cur_node.right = self._delete_node(temp.data, cur_node.right)
def _delete_node(self, data, cur_node):
if cur_node is None:
return cur_node
# 削除するデータが現在のノードより小さい場合は左にある
if data < cur_node.data:
cur_node.left = self._delete_node(data, cur_node.left)
# 削除するデータが現在のノードより大きい場合は右にある
elif data > cur_node.data:
cur_node.right = self._delete_node(data, cur_node.right)
# データを削除する
else:
# 子ノードが一つまたはなしの場合
if cur_node.left is None:
temp = cur_node.right
cur_node = None
return temp
elif cur_node.right is None:
temp = cur_node.left
cur_node.left = None
return temp
# 子ノードが2つの場合は右側の一番小さい値
temp = self.find_min(cur_node.right)
cur_node.data = temp.data
cur_node.right = self._delete_node(temp.data, cur_node.right)
return cur_node
def _delete_node(self, data, cur_node):
if cur_node is None:
return cur_node
# 削除するデータが現在のノードより小さい場合は左にある
if data < cur_node.data:
cur_node.left = self._delete_node(data, cur_node.left)
# 削除するデータが現在のノードより大きい場合は右にある
elif data > cur_node.data:
cur_node.right = self._delete_node(data, cur_node.right)
# データを削除する
else:
# 子ノードが一つまたはなしの場合
if cur_node.left is None:
temp = cur_node.right
cur_node = None
return temp
elif cur_node.right is None:
temp = cur_node.left
cur_node.left = None
return temp
# 子ノードが2つの場合は右側の一番小さい値
temp = self.find_min(cur_node.right)
cur_node.data = temp.data
cur_node.right = self._delete_node(temp.data, cur_node.right)
return cur_node
2分探索木コードまとめ
def __init__(self, data=None):
self._insert(data, self.root)
def _insert(self, data, cur_node):
if cur_node.left is None:
cur_node.left = Node(data)
self._insert(data, cur_node.left)
elif data > cur_node.data:
if cur_node.right is None:
cur_node.right = Node(data)
self._insert(data, cur_node.right)
def inorder_print_tree(self):
self._inorder_print_tree(self.root)
def _inorder_print_tree(self, cur_node):
self._inorder_print_tree(cur_node.left)
print(str(cur_node.data))
self._inorder_print_tree(cur_node.right)
def find_min(self, cur_node=None):
def delete_node(self, data):
self.root = self._delete_node(data, self.root)
def _delete_node(self, data, cur_node):
# 削除するデータが現在のノードより小さい場合は左にある
cur_node.left = self._delete_node(data, cur_node.left)
# 削除するデータが現在のノードより大きい場合は右にある
elif data > cur_node.data:
cur_node.right = self._delete_node(data, cur_node.right)
if cur_node.left is None:
elif cur_node.right is None:
temp = self.find_min(cur_node.right)
cur_node.data = temp.data
cur_node.right = self._delete_node(temp.data, cur_node.right)
def find_rec(self, data):
is_found = self._find(data, self.root)
def _find(self, data, cur_node):
if data > cur_node.data and cur_node.right:
return self._find(data, cur_node.right)
elif data < cur_node.data and cur_node.left:
return self._find(data, cur_node.left)
if data == cur_node.data:
def find_by_while(self, data):
if data == cur_node.data:
cur_node = cur_node.right
print(bst.find_by_while(6))
print(bst.find_by_while(12))
class Node:
def __init__(self, data=None):
self.data = data
self.left = None
self.right = None
class BST:
def __init__(self):
self.root = None
def insert(self, data):
if self.root is None:
self.root = Node(data)
else:
self._insert(data, self.root)
def _insert(self, data, cur_node):
if data < cur_node.data:
if cur_node.left is None:
cur_node.left = Node(data)
else:
self._insert(data, cur_node.left)
elif data > cur_node.data:
if cur_node.right is None:
cur_node.right = Node(data)
else:
self._insert(data, cur_node.right)
else:
print("同じ値があります。")
def inorder_print_tree(self):
if self.root:
self._inorder_print_tree(self.root)
def _inorder_print_tree(self, cur_node):
if cur_node:
self._inorder_print_tree(cur_node.left)
print(str(cur_node.data))
self._inorder_print_tree(cur_node.right)
def find_min(self, cur_node=None):
if self.root is None:
return None
while cur_node.left:
cur_node = cur_node.left
return cur_node
def delete_node(self, data):
if self.root is None:
return False
else:
self.root = self._delete_node(data, self.root)
def _delete_node(self, data, cur_node):
if cur_node is None:
return cur_node
# 削除するデータが現在のノードより小さい場合は左にある
if data < cur_node.data:
cur_node.left = self._delete_node(data, cur_node.left)
# 削除するデータが現在のノードより大きい場合は右にある
elif data > cur_node.data:
cur_node.right = self._delete_node(data, cur_node.right)
# データを削除する
else:
# 子ノードが一つまたはなしの場合
if cur_node.left is None:
temp = cur_node.right
cur_node = None
return temp
elif cur_node.right is None:
temp = cur_node.left
cur_node.left = None
return temp
# 子ノードが2つの場合は右側の一番小さい値
temp = self.find_min(cur_node.right)
cur_node.data = temp.data
cur_node.right = self._delete_node(temp.data, cur_node.right)
return cur_node
# 再帰
def find_rec(self, data):
if self.root:
is_found = self._find(data, self.root)
if is_found:
return True
return False
else:
return None
def _find(self, data, cur_node):
if data > cur_node.data and cur_node.right:
return self._find(data, cur_node.right)
elif data < cur_node.data and cur_node.left:
return self._find(data, cur_node.left)
if data == cur_node.data:
return True
# ループ
def find_by_while(self, data):
if self.root is None:
return None
cur_node = self.root
while cur_node:
if data == cur_node.data:
return True
if data < cur_node.data:
cur_node = cur_node.left
else:
cur_node = cur_node.right
return False
bst = BST()
bst.insert(8)
bst.insert(3)
bst.insert(10)
bst.insert(1)
bst.insert(6)
bst.insert(14)
bst.insert(4)
bst.insert(7)
bst.insert(13)
# 8
# /\
# 3 10
# /\ \
# 1 6 14
# /\ /
# 4 7 13
bst.inorder_print_tree()
# 1
# 3
# 4
# 6
# 7
# 8
# 10
# 13
# 14
# # True
print(bst.find_rec(6))
print(bst.find_by_while(6))
# # Flase
print(bst.find_rec(12))
print(bst.find_by_while(12))
bst.delete_node(3)
# 8
# /\
# 4 10
# /\ \
# 1 6 14
# \ /
# 7 13
bst.inorder_print_tree()
# 1
# 4
# 6
# 7
# 8
# 10
# 13
# 14
class Node:
def __init__(self, data=None):
self.data = data
self.left = None
self.right = None
class BST:
def __init__(self):
self.root = None
def insert(self, data):
if self.root is None:
self.root = Node(data)
else:
self._insert(data, self.root)
def _insert(self, data, cur_node):
if data < cur_node.data:
if cur_node.left is None:
cur_node.left = Node(data)
else:
self._insert(data, cur_node.left)
elif data > cur_node.data:
if cur_node.right is None:
cur_node.right = Node(data)
else:
self._insert(data, cur_node.right)
else:
print("同じ値があります。")
def inorder_print_tree(self):
if self.root:
self._inorder_print_tree(self.root)
def _inorder_print_tree(self, cur_node):
if cur_node:
self._inorder_print_tree(cur_node.left)
print(str(cur_node.data))
self._inorder_print_tree(cur_node.right)
def find_min(self, cur_node=None):
if self.root is None:
return None
while cur_node.left:
cur_node = cur_node.left
return cur_node
def delete_node(self, data):
if self.root is None:
return False
else:
self.root = self._delete_node(data, self.root)
def _delete_node(self, data, cur_node):
if cur_node is None:
return cur_node
# 削除するデータが現在のノードより小さい場合は左にある
if data < cur_node.data:
cur_node.left = self._delete_node(data, cur_node.left)
# 削除するデータが現在のノードより大きい場合は右にある
elif data > cur_node.data:
cur_node.right = self._delete_node(data, cur_node.right)
# データを削除する
else:
# 子ノードが一つまたはなしの場合
if cur_node.left is None:
temp = cur_node.right
cur_node = None
return temp
elif cur_node.right is None:
temp = cur_node.left
cur_node.left = None
return temp
# 子ノードが2つの場合は右側の一番小さい値
temp = self.find_min(cur_node.right)
cur_node.data = temp.data
cur_node.right = self._delete_node(temp.data, cur_node.right)
return cur_node
# 再帰
def find_rec(self, data):
if self.root:
is_found = self._find(data, self.root)
if is_found:
return True
return False
else:
return None
def _find(self, data, cur_node):
if data > cur_node.data and cur_node.right:
return self._find(data, cur_node.right)
elif data < cur_node.data and cur_node.left:
return self._find(data, cur_node.left)
if data == cur_node.data:
return True
# ループ
def find_by_while(self, data):
if self.root is None:
return None
cur_node = self.root
while cur_node:
if data == cur_node.data:
return True
if data < cur_node.data:
cur_node = cur_node.left
else:
cur_node = cur_node.right
return False
bst = BST()
bst.insert(8)
bst.insert(3)
bst.insert(10)
bst.insert(1)
bst.insert(6)
bst.insert(14)
bst.insert(4)
bst.insert(7)
bst.insert(13)
# 8
# /\
# 3 10
# /\ \
# 1 6 14
# /\ /
# 4 7 13
bst.inorder_print_tree()
# 1
# 3
# 4
# 6
# 7
# 8
# 10
# 13
# 14
# # True
print(bst.find_rec(6))
print(bst.find_by_while(6))
# # Flase
print(bst.find_rec(12))
print(bst.find_by_while(12))
bst.delete_node(3)
# 8
# /\
# 4 10
# /\ \
# 1 6 14
# \ /
# 7 13
bst.inorder_print_tree()
# 1
# 4
# 6
# 7
# 8
# 10
# 13
# 14