[Python] 二分探索木/ Binary Search Tree

2分探索木をPythonで実装します。

二分探索木

二分探索木(にぶんたんさくぎ、: binary search tree)は、コンピュータプログラムにおいて、「左の子孫の値 ≤ 親の値 ≤ 右の子孫の値」という制約を持つ二分木である。探索木のうちで最も基本的な木構造である。

出典: フリー百科事典『ウィキペディア(Wikipedia)』

探索の計算量は平均的には O(log2 N)  と、通常の2分木と比較して減らせる可能性のある構造となります。

上の定義では、「左の子孫の値 ≤ 親の値 ≤ 右の子孫の値」となっていますが、ここでは、 「左の子孫の値 < 親の値 < 右の子孫の値」と考えます。

ノードの定義

ノードを定義します。

通常の2分木と変わりません。

class Node:
    def __init__(self, data=None):
        self.data = data
        self.left = None
        self.right = None

BSTクラス

2分探索木のクラスをBSTクラスとして定義します。

class BST:
    def __init__(self):
        self.root = None

挿入

BSTクラスのメソッドとして、データの挿入を定義します。

アルゴリズムは以下のように考えます。

ルートから手順を開始する。
着目しているノードと目的の値を比較する。「目的の値 < 着目しているノード」なら左の子、「着目しているノード ≤ 目的の値」なら右の子が、次の着目ノードとなる。
次の着目ノードが存在しなければ(現在の着目ノードが葉であれば)、次の着目ノードの位置にデータを挿入。存在すれば、次の着目ノードに移って繰り返し。
挿入の計算量は木の高さに比例し、平衡状態であれば O(log N) となる。

出典: フリー百科事典『ウィキペディア(Wikipedia)』
    def insert(self, data):
        if self.root is None:
            self.root = Node(data)
        else:
            self._insert(data, self.root)

    def _insert(self, data, cur_node):
        if data < cur_node.data:
            if cur_node.left is None:
                cur_node.left = Node(data)
            else:
                self._insert(data, cur_node.left)
        elif data > cur_node.data:
            if cur_node.right is None:
                cur_node.right = Node(data)
            else:
                self._insert(data, cur_node.right)
        else:
            print("同じ値があります。")

データの列挙

BSTクラスのメソッドとして、全データのプリントを定義します。

アルゴリズムは以下のように考えます。

以下のように 再帰呼び出し を使うことで、二分探索木に登録された全データをソートされた順序で列挙できる。
左の子をルートとする部分木に対して、この処理を再帰的に適用する。
親を表示する。
右の子をルートとする部分木に対して、この処理を再帰的に適用する。

出典: フリー百科事典『ウィキペディア(Wikipedia)』
    def inorder_print_tree(self):
        if self.root:
            self._inorder_print_tree(self.root)

    def _inorder_print_tree(self, cur_node):
        if cur_node:
            self._inorder_print_tree(cur_node.left)
            print(str(cur_node.data))
            self._inorder_print_tree(cur_node.right)

データの探索

BSTクラスのメソッドとして、あるデータの探索を定義します。

アルゴリズムは以下のように考えます。

ルートから手順を開始する。
着目しているノードと目的の値を比較する。等しいか、着目ノードが存在しなければ終了。
「目的の値 < 着目しているノード」なら左の子、「着目しているノード < 目的の値」なら右の子へ移って繰り返し。
探索の計算量は木の高さに比例し、平衡状態であれば O(log N) となる。

出典: フリー百科事典『ウィキペディア(Wikipedia)』

再帰的に書きます。

    def find_rec(self, data):
        if self.root:
            is_found = self._find(data, self.root)
            if is_found:
                return True
            return False
        else:
            return None

    def _find(self, data, cur_node):
        if data > cur_node.data and cur_node.right:
            return self._find(data, cur_node.right)
        elif data < cur_node.data and cur_node.left:
            return self._find(data, cur_node.left)
        if data == cur_node.data:
            return True

ループを使って書きます。

def find_by_while(self, data):
        if self.root is None:
            return None
        cur_node = self.root
        while cur_node:
            if data == cur_node.data:
                return True
            if data < cur_node.data:
                cur_node = cur_node.left
            else:
                cur_node = cur_node.right
        return False

データの削除

BSTクラスのメソッドとして、あるデータの削除を定義します。

アルゴリズムは以下のように考えます。

1.ルートから手順を開始する。
2.着目しているノードと目的の値を比較する。「目的の値 < 着目しているノード」なら左の子、「着目しているノード ≤ 目的の値」なら右の子が、次に着目するノードとなる。
3.着目ノードが削除する対象(以下、削除ノード)であり、削除ノードが子どもを持たないなら、そのノードをそのまま削除する。
4.削除ノードが子を一つしかもっていない場合は、削除ノードを削除してその子と置き換える。
5.削除ノードが子を二つ持つ場合

5-1.削除ノードの右の子から最小の値を探索する。
5-2.1 で探索してきたノード(以下、探索ノード)を削除対象のノードと置き換えて、削除対象のノードを削除する。このとき 探索ノードの右の子を探索ノードの元位置に置き換える。

出典: フリー百科事典『ウィキペディア(Wikipedia)』

以下のコードでは、子ノードが2つある場合は、削除対象ノードのデータのみを置き換えて、その後右のsubtreeから置き換えノードを再帰的に削除することで、実装しています。

    def _delete_node(self, data, cur_node):
        if cur_node is None:
            return cur_node
        # 削除するデータが現在のノードより小さい場合は左にある
        if data < cur_node.data:
            cur_node.left = self._delete_node(data, cur_node.left)
        # 削除するデータが現在のノードより大きい場合は右にある
        elif data > cur_node.data:
            cur_node.right = self._delete_node(data, cur_node.right)
        # データを削除する
        else:
            # 子ノードが一つまたはなしの場合
            if cur_node.left is None:
                temp = cur_node.right
                cur_node = None
                return temp
            elif cur_node.right is None:
                temp = cur_node.left
                cur_node.left = None
                return temp
            
            # 子ノードが2つの場合は右側の一番小さい値
            temp = self.find_min(cur_node.right)
            cur_node.data = temp.data
            cur_node.right = self._delete_node(temp.data, cur_node.right)
        return cur_node

2分探索木コードまとめ

class Node:
    def __init__(self, data=None):
        self.data = data
        self.left = None
        self.right = None

class BST:
    def __init__(self):
        self.root = None

    def insert(self, data):
        if self.root is None:
            self.root = Node(data)
        else:
            self._insert(data, self.root)

    def _insert(self, data, cur_node):
        if data < cur_node.data:
            if cur_node.left is None:
                cur_node.left = Node(data)
            else:
                self._insert(data, cur_node.left)
        elif data > cur_node.data:
            if cur_node.right is None:
                cur_node.right = Node(data)
            else:
                self._insert(data, cur_node.right)
        else:
            print("同じ値があります。")

    def inorder_print_tree(self):
        if self.root:
            self._inorder_print_tree(self.root)

    def _inorder_print_tree(self, cur_node):
        if cur_node:
            self._inorder_print_tree(cur_node.left)
            print(str(cur_node.data))
            self._inorder_print_tree(cur_node.right)
            
    def find_min(self, cur_node=None):
        if self.root is None:
            return None
        while cur_node.left:
            cur_node = cur_node.left
        return cur_node

    def delete_node(self, data):
        if self.root is None:
            return False
        else:
            self.root = self._delete_node(data, self.root)

    def _delete_node(self, data, cur_node):
        if cur_node is None:
            return cur_node
        # 削除するデータが現在のノードより小さい場合は左にある
        if data < cur_node.data:
            cur_node.left = self._delete_node(data, cur_node.left)
        # 削除するデータが現在のノードより大きい場合は右にある
        elif data > cur_node.data:
            cur_node.right = self._delete_node(data, cur_node.right)
        # データを削除する
        else:
            # 子ノードが一つまたはなしの場合
            if cur_node.left is None:
                temp = cur_node.right
                cur_node = None
                return temp
            elif cur_node.right is None:
                temp = cur_node.left
                cur_node.left = None
                return temp
            
            # 子ノードが2つの場合は右側の一番小さい値
            temp = self.find_min(cur_node.right)
            cur_node.data = temp.data
            cur_node.right = self._delete_node(temp.data, cur_node.right)
        return cur_node
                

    # 再帰
    def find_rec(self, data):
        if self.root:
            is_found = self._find(data, self.root)
            if is_found:
                return True
            return False
        else:
            return None

    def _find(self, data, cur_node):
        if data > cur_node.data and cur_node.right:
            return self._find(data, cur_node.right)
        elif data < cur_node.data and cur_node.left:
            return self._find(data, cur_node.left)
        if data == cur_node.data:
            return True

    # ループ
    def find_by_while(self, data):
        if self.root is None:
            return None
        cur_node = self.root
        while cur_node:
            if data == cur_node.data:
                return True
            if data < cur_node.data:
                cur_node = cur_node.left
            else:
                cur_node = cur_node.right
        return False



bst = BST()

bst.insert(8)
bst.insert(3)
bst.insert(10)
bst.insert(1)
bst.insert(6)
bst.insert(14)
bst.insert(4)
bst.insert(7)
bst.insert(13)

#     8
#    /\
#   3  10
#  /\    \
# 1  6     14
#    /\    /
#   4  7  13

bst.inorder_print_tree()
# 1
# 3
# 4
# 6
# 7
# 8
# 10
# 13
# 14

# # True
print(bst.find_rec(6))
print(bst.find_by_while(6))
# # Flase
print(bst.find_rec(12))
print(bst.find_by_while(12))


bst.delete_node(3)
#     8
#    /\
#   4  10
#  /\    \
# 1  6     14
#     \    /
#      7  13

bst.inorder_print_tree()
# 1
# 4
# 6
# 7
# 8
# 10
# 13
# 14