[Python] 二分探索木/ Binary Search Tree

2分探索木をPythonで実装します。

二分探索木

二分探索木(にぶんたんさくぎ、: binary search tree)は、コンピュータプログラムにおいて、「左の子孫の値 ≤ 親の値 ≤ 右の子孫の値」という制約を持つ二分木である。探索木のうちで最も基本的な木構造である。

出典: フリー百科事典『ウィキペディア(Wikipedia)』

探索の計算量は平均的には O(log2 N)  と、通常の2分木と比較して減らせる可能性のある構造となります。

上の定義では、「左の子孫の値 ≤ 親の値 ≤ 右の子孫の値」となっていますが、ここでは、 「左の子孫の値 < 親の値 < 右の子孫の値」と考えます。

ノードの定義

ノードを定義します。

通常の2分木と変わりません。

class Node:
def __init__(self, data=None):
self.data = data
self.left = None
self.right = None
class Node: def __init__(self, data=None): self.data = data self.left = None self.right = None
class Node:
    def __init__(self, data=None):
        self.data = data
        self.left = None
        self.right = None

BSTクラス

2分探索木のクラスをBSTクラスとして定義します。

class BST:
def __init__(self):
self.root = None
class BST: def __init__(self): self.root = None
class BST:
    def __init__(self):
        self.root = None

挿入

BSTクラスのメソッドとして、データの挿入を定義します。

アルゴリズムは以下のように考えます。

ルートから手順を開始する。
着目しているノードと目的の値を比較する。「目的の値 < 着目しているノード」なら左の子、「着目しているノード ≤ 目的の値」なら右の子が、次の着目ノードとなる。
次の着目ノードが存在しなければ(現在の着目ノードが葉であれば)、次の着目ノードの位置にデータを挿入。存在すれば、次の着目ノードに移って繰り返し。
挿入の計算量は木の高さに比例し、平衡状態であれば O(log N) となる。

出典: フリー百科事典『ウィキペディア(Wikipedia)』
def insert(self, data):
if self.root is None:
self.root = Node(data)
else:
self._insert(data, self.root)
def _insert(self, data, cur_node):
if data < cur_node.data:
if cur_node.left is None:
cur_node.left = Node(data)
else:
self._insert(data, cur_node.left)
elif data > cur_node.data:
if cur_node.right is None:
cur_node.right = Node(data)
else:
self._insert(data, cur_node.right)
else:
print("同じ値があります。")
def insert(self, data): if self.root is None: self.root = Node(data) else: self._insert(data, self.root) def _insert(self, data, cur_node): if data < cur_node.data: if cur_node.left is None: cur_node.left = Node(data) else: self._insert(data, cur_node.left) elif data > cur_node.data: if cur_node.right is None: cur_node.right = Node(data) else: self._insert(data, cur_node.right) else: print("同じ値があります。")
    def insert(self, data):
        if self.root is None:
            self.root = Node(data)
        else:
            self._insert(data, self.root)

    def _insert(self, data, cur_node):
        if data < cur_node.data:
            if cur_node.left is None:
                cur_node.left = Node(data)
            else:
                self._insert(data, cur_node.left)
        elif data > cur_node.data:
            if cur_node.right is None:
                cur_node.right = Node(data)
            else:
                self._insert(data, cur_node.right)
        else:
            print("同じ値があります。")

データの列挙

BSTクラスのメソッドとして、全データのプリントを定義します。

アルゴリズムは以下のように考えます。

以下のように 再帰呼び出し を使うことで、二分探索木に登録された全データをソートされた順序で列挙できる。
左の子をルートとする部分木に対して、この処理を再帰的に適用する。
親を表示する。
右の子をルートとする部分木に対して、この処理を再帰的に適用する。

出典: フリー百科事典『ウィキペディア(Wikipedia)』
def inorder_print_tree(self):
if self.root:
self._inorder_print_tree(self.root)
def _inorder_print_tree(self, cur_node):
if cur_node:
self._inorder_print_tree(cur_node.left)
print(str(cur_node.data))
self._inorder_print_tree(cur_node.right)
def inorder_print_tree(self): if self.root: self._inorder_print_tree(self.root) def _inorder_print_tree(self, cur_node): if cur_node: self._inorder_print_tree(cur_node.left) print(str(cur_node.data)) self._inorder_print_tree(cur_node.right)
    def inorder_print_tree(self):
        if self.root:
            self._inorder_print_tree(self.root)

    def _inorder_print_tree(self, cur_node):
        if cur_node:
            self._inorder_print_tree(cur_node.left)
            print(str(cur_node.data))
            self._inorder_print_tree(cur_node.right)

データの探索

BSTクラスのメソッドとして、あるデータの探索を定義します。

アルゴリズムは以下のように考えます。

ルートから手順を開始する。
着目しているノードと目的の値を比較する。等しいか、着目ノードが存在しなければ終了。
「目的の値 < 着目しているノード」なら左の子、「着目しているノード < 目的の値」なら右の子へ移って繰り返し。
探索の計算量は木の高さに比例し、平衡状態であれば O(log N) となる。

出典: フリー百科事典『ウィキペディア(Wikipedia)』

再帰的に書きます。

def find_rec(self, data):
if self.root:
is_found = self._find(data, self.root)
if is_found:
return True
return False
else:
return None
def _find(self, data, cur_node):
if data > cur_node.data and cur_node.right:
return self._find(data, cur_node.right)
elif data < cur_node.data and cur_node.left:
return self._find(data, cur_node.left)
if data == cur_node.data:
return True
def find_rec(self, data): if self.root: is_found = self._find(data, self.root) if is_found: return True return False else: return None def _find(self, data, cur_node): if data > cur_node.data and cur_node.right: return self._find(data, cur_node.right) elif data < cur_node.data and cur_node.left: return self._find(data, cur_node.left) if data == cur_node.data: return True
    def find_rec(self, data):
        if self.root:
            is_found = self._find(data, self.root)
            if is_found:
                return True
            return False
        else:
            return None

    def _find(self, data, cur_node):
        if data > cur_node.data and cur_node.right:
            return self._find(data, cur_node.right)
        elif data < cur_node.data and cur_node.left:
            return self._find(data, cur_node.left)
        if data == cur_node.data:
            return True

ループを使って書きます。

def find_by_while(self, data):
if self.root is None:
return None
cur_node = self.root
while cur_node:
if data == cur_node.data:
return True
if data < cur_node.data:
cur_node = cur_node.left
else:
cur_node = cur_node.right
return False
def find_by_while(self, data): if self.root is None: return None cur_node = self.root while cur_node: if data == cur_node.data: return True if data < cur_node.data: cur_node = cur_node.left else: cur_node = cur_node.right return False
def find_by_while(self, data):
        if self.root is None:
            return None
        cur_node = self.root
        while cur_node:
            if data == cur_node.data:
                return True
            if data < cur_node.data:
                cur_node = cur_node.left
            else:
                cur_node = cur_node.right
        return False

データの削除

BSTクラスのメソッドとして、あるデータの削除を定義します。

アルゴリズムは以下のように考えます。

1.ルートから手順を開始する。
2.着目しているノードと目的の値を比較する。「目的の値 < 着目しているノード」なら左の子、「着目しているノード ≤ 目的の値」なら右の子が、次に着目するノードとなる。
3.着目ノードが削除する対象(以下、削除ノード)であり、削除ノードが子どもを持たないなら、そのノードをそのまま削除する。
4.削除ノードが子を一つしかもっていない場合は、削除ノードを削除してその子と置き換える。
5.削除ノードが子を二つ持つ場合

5-1.削除ノードの右の子から最小の値を探索する。
5-2.1 で探索してきたノード(以下、探索ノード)を削除対象のノードと置き換えて、削除対象のノードを削除する。このとき 探索ノードの右の子を探索ノードの元位置に置き換える。

出典: フリー百科事典『ウィキペディア(Wikipedia)』

以下のコードでは、子ノードが2つある場合は、削除対象ノードのデータのみを置き換えて、その後右のsubtreeから置き換えノードを再帰的に削除することで、実装しています。

def _delete_node(self, data, cur_node):
if cur_node is None:
return cur_node
# 削除するデータが現在のノードより小さい場合は左にある
if data < cur_node.data:
cur_node.left = self._delete_node(data, cur_node.left)
# 削除するデータが現在のノードより大きい場合は右にある
elif data > cur_node.data:
cur_node.right = self._delete_node(data, cur_node.right)
# データを削除する
else:
# 子ノードが一つまたはなしの場合
if cur_node.left is None:
temp = cur_node.right
cur_node = None
return temp
elif cur_node.right is None:
temp = cur_node.left
cur_node.left = None
return temp
# 子ノードが2つの場合は右側の一番小さい値
temp = self.find_min(cur_node.right)
cur_node.data = temp.data
cur_node.right = self._delete_node(temp.data, cur_node.right)
return cur_node
def _delete_node(self, data, cur_node): if cur_node is None: return cur_node # 削除するデータが現在のノードより小さい場合は左にある if data < cur_node.data: cur_node.left = self._delete_node(data, cur_node.left) # 削除するデータが現在のノードより大きい場合は右にある elif data > cur_node.data: cur_node.right = self._delete_node(data, cur_node.right) # データを削除する else: # 子ノードが一つまたはなしの場合 if cur_node.left is None: temp = cur_node.right cur_node = None return temp elif cur_node.right is None: temp = cur_node.left cur_node.left = None return temp # 子ノードが2つの場合は右側の一番小さい値 temp = self.find_min(cur_node.right) cur_node.data = temp.data cur_node.right = self._delete_node(temp.data, cur_node.right) return cur_node
    def _delete_node(self, data, cur_node):
        if cur_node is None:
            return cur_node
        # 削除するデータが現在のノードより小さい場合は左にある
        if data < cur_node.data:
            cur_node.left = self._delete_node(data, cur_node.left)
        # 削除するデータが現在のノードより大きい場合は右にある
        elif data > cur_node.data:
            cur_node.right = self._delete_node(data, cur_node.right)
        # データを削除する
        else:
            # 子ノードが一つまたはなしの場合
            if cur_node.left is None:
                temp = cur_node.right
                cur_node = None
                return temp
            elif cur_node.right is None:
                temp = cur_node.left
                cur_node.left = None
                return temp
            
            # 子ノードが2つの場合は右側の一番小さい値
            temp = self.find_min(cur_node.right)
            cur_node.data = temp.data
            cur_node.right = self._delete_node(temp.data, cur_node.right)
        return cur_node

2分探索木コードまとめ

class Node:
def __init__(self, data=None):
self.data = data
self.left = None
self.right = None
class BST:
def __init__(self):
self.root = None
def insert(self, data):
if self.root is None:
self.root = Node(data)
else:
self._insert(data, self.root)
def _insert(self, data, cur_node):
if data < cur_node.data:
if cur_node.left is None:
cur_node.left = Node(data)
else:
self._insert(data, cur_node.left)
elif data > cur_node.data:
if cur_node.right is None:
cur_node.right = Node(data)
else:
self._insert(data, cur_node.right)
else:
print("同じ値があります。")
def inorder_print_tree(self):
if self.root:
self._inorder_print_tree(self.root)
def _inorder_print_tree(self, cur_node):
if cur_node:
self._inorder_print_tree(cur_node.left)
print(str(cur_node.data))
self._inorder_print_tree(cur_node.right)
def find_min(self, cur_node=None):
if self.root is None:
return None
while cur_node.left:
cur_node = cur_node.left
return cur_node
def delete_node(self, data):
if self.root is None:
return False
else:
self.root = self._delete_node(data, self.root)
def _delete_node(self, data, cur_node):
if cur_node is None:
return cur_node
# 削除するデータが現在のノードより小さい場合は左にある
if data < cur_node.data:
cur_node.left = self._delete_node(data, cur_node.left)
# 削除するデータが現在のノードより大きい場合は右にある
elif data > cur_node.data:
cur_node.right = self._delete_node(data, cur_node.right)
# データを削除する
else:
# 子ノードが一つまたはなしの場合
if cur_node.left is None:
temp = cur_node.right
cur_node = None
return temp
elif cur_node.right is None:
temp = cur_node.left
cur_node.left = None
return temp
# 子ノードが2つの場合は右側の一番小さい値
temp = self.find_min(cur_node.right)
cur_node.data = temp.data
cur_node.right = self._delete_node(temp.data, cur_node.right)
return cur_node
# 再帰
def find_rec(self, data):
if self.root:
is_found = self._find(data, self.root)
if is_found:
return True
return False
else:
return None
def _find(self, data, cur_node):
if data > cur_node.data and cur_node.right:
return self._find(data, cur_node.right)
elif data < cur_node.data and cur_node.left:
return self._find(data, cur_node.left)
if data == cur_node.data:
return True
# ループ
def find_by_while(self, data):
if self.root is None:
return None
cur_node = self.root
while cur_node:
if data == cur_node.data:
return True
if data < cur_node.data:
cur_node = cur_node.left
else:
cur_node = cur_node.right
return False
bst = BST()
bst.insert(8)
bst.insert(3)
bst.insert(10)
bst.insert(1)
bst.insert(6)
bst.insert(14)
bst.insert(4)
bst.insert(7)
bst.insert(13)
# 8
# /\
# 3 10
# /\ \
# 1 6 14
# /\ /
# 4 7 13
bst.inorder_print_tree()
# 1
# 3
# 4
# 6
# 7
# 8
# 10
# 13
# 14
# # True
print(bst.find_rec(6))
print(bst.find_by_while(6))
# # Flase
print(bst.find_rec(12))
print(bst.find_by_while(12))
bst.delete_node(3)
# 8
# /\
# 4 10
# /\ \
# 1 6 14
# \ /
# 7 13
bst.inorder_print_tree()
# 1
# 4
# 6
# 7
# 8
# 10
# 13
# 14
class Node: def __init__(self, data=None): self.data = data self.left = None self.right = None class BST: def __init__(self): self.root = None def insert(self, data): if self.root is None: self.root = Node(data) else: self._insert(data, self.root) def _insert(self, data, cur_node): if data < cur_node.data: if cur_node.left is None: cur_node.left = Node(data) else: self._insert(data, cur_node.left) elif data > cur_node.data: if cur_node.right is None: cur_node.right = Node(data) else: self._insert(data, cur_node.right) else: print("同じ値があります。") def inorder_print_tree(self): if self.root: self._inorder_print_tree(self.root) def _inorder_print_tree(self, cur_node): if cur_node: self._inorder_print_tree(cur_node.left) print(str(cur_node.data)) self._inorder_print_tree(cur_node.right) def find_min(self, cur_node=None): if self.root is None: return None while cur_node.left: cur_node = cur_node.left return cur_node def delete_node(self, data): if self.root is None: return False else: self.root = self._delete_node(data, self.root) def _delete_node(self, data, cur_node): if cur_node is None: return cur_node # 削除するデータが現在のノードより小さい場合は左にある if data < cur_node.data: cur_node.left = self._delete_node(data, cur_node.left) # 削除するデータが現在のノードより大きい場合は右にある elif data > cur_node.data: cur_node.right = self._delete_node(data, cur_node.right) # データを削除する else: # 子ノードが一つまたはなしの場合 if cur_node.left is None: temp = cur_node.right cur_node = None return temp elif cur_node.right is None: temp = cur_node.left cur_node.left = None return temp # 子ノードが2つの場合は右側の一番小さい値 temp = self.find_min(cur_node.right) cur_node.data = temp.data cur_node.right = self._delete_node(temp.data, cur_node.right) return cur_node # 再帰 def find_rec(self, data): if self.root: is_found = self._find(data, self.root) if is_found: return True return False else: return None def _find(self, data, cur_node): if data > cur_node.data and cur_node.right: return self._find(data, cur_node.right) elif data < cur_node.data and cur_node.left: return self._find(data, cur_node.left) if data == cur_node.data: return True # ループ def find_by_while(self, data): if self.root is None: return None cur_node = self.root while cur_node: if data == cur_node.data: return True if data < cur_node.data: cur_node = cur_node.left else: cur_node = cur_node.right return False bst = BST() bst.insert(8) bst.insert(3) bst.insert(10) bst.insert(1) bst.insert(6) bst.insert(14) bst.insert(4) bst.insert(7) bst.insert(13) # 8 # /\ # 3 10 # /\ \ # 1 6 14 # /\ / # 4 7 13 bst.inorder_print_tree() # 1 # 3 # 4 # 6 # 7 # 8 # 10 # 13 # 14 # # True print(bst.find_rec(6)) print(bst.find_by_while(6)) # # Flase print(bst.find_rec(12)) print(bst.find_by_while(12)) bst.delete_node(3) # 8 # /\ # 4 10 # /\ \ # 1 6 14 # \ / # 7 13 bst.inorder_print_tree() # 1 # 4 # 6 # 7 # 8 # 10 # 13 # 14
class Node:
    def __init__(self, data=None):
        self.data = data
        self.left = None
        self.right = None

class BST:
    def __init__(self):
        self.root = None

    def insert(self, data):
        if self.root is None:
            self.root = Node(data)
        else:
            self._insert(data, self.root)

    def _insert(self, data, cur_node):
        if data < cur_node.data:
            if cur_node.left is None:
                cur_node.left = Node(data)
            else:
                self._insert(data, cur_node.left)
        elif data > cur_node.data:
            if cur_node.right is None:
                cur_node.right = Node(data)
            else:
                self._insert(data, cur_node.right)
        else:
            print("同じ値があります。")

    def inorder_print_tree(self):
        if self.root:
            self._inorder_print_tree(self.root)

    def _inorder_print_tree(self, cur_node):
        if cur_node:
            self._inorder_print_tree(cur_node.left)
            print(str(cur_node.data))
            self._inorder_print_tree(cur_node.right)
            
    def find_min(self, cur_node=None):
        if self.root is None:
            return None
        while cur_node.left:
            cur_node = cur_node.left
        return cur_node

    def delete_node(self, data):
        if self.root is None:
            return False
        else:
            self.root = self._delete_node(data, self.root)

    def _delete_node(self, data, cur_node):
        if cur_node is None:
            return cur_node
        # 削除するデータが現在のノードより小さい場合は左にある
        if data < cur_node.data:
            cur_node.left = self._delete_node(data, cur_node.left)
        # 削除するデータが現在のノードより大きい場合は右にある
        elif data > cur_node.data:
            cur_node.right = self._delete_node(data, cur_node.right)
        # データを削除する
        else:
            # 子ノードが一つまたはなしの場合
            if cur_node.left is None:
                temp = cur_node.right
                cur_node = None
                return temp
            elif cur_node.right is None:
                temp = cur_node.left
                cur_node.left = None
                return temp
            
            # 子ノードが2つの場合は右側の一番小さい値
            temp = self.find_min(cur_node.right)
            cur_node.data = temp.data
            cur_node.right = self._delete_node(temp.data, cur_node.right)
        return cur_node
                

    # 再帰
    def find_rec(self, data):
        if self.root:
            is_found = self._find(data, self.root)
            if is_found:
                return True
            return False
        else:
            return None

    def _find(self, data, cur_node):
        if data > cur_node.data and cur_node.right:
            return self._find(data, cur_node.right)
        elif data < cur_node.data and cur_node.left:
            return self._find(data, cur_node.left)
        if data == cur_node.data:
            return True

    # ループ
    def find_by_while(self, data):
        if self.root is None:
            return None
        cur_node = self.root
        while cur_node:
            if data == cur_node.data:
                return True
            if data < cur_node.data:
                cur_node = cur_node.left
            else:
                cur_node = cur_node.right
        return False



bst = BST()

bst.insert(8)
bst.insert(3)
bst.insert(10)
bst.insert(1)
bst.insert(6)
bst.insert(14)
bst.insert(4)
bst.insert(7)
bst.insert(13)

#     8
#    /\
#   3  10
#  /\    \
# 1  6     14
#    /\    /
#   4  7  13

bst.inorder_print_tree()
# 1
# 3
# 4
# 6
# 7
# 8
# 10
# 13
# 14

# # True
print(bst.find_rec(6))
print(bst.find_by_while(6))
# # Flase
print(bst.find_rec(12))
print(bst.find_by_while(12))


bst.delete_node(3)
#     8
#    /\
#   4  10
#  /\    \
# 1  6     14
#     \    /
#      7  13

bst.inorder_print_tree()
# 1
# 4
# 6
# 7
# 8
# 10
# 13
# 14