[データ構造] 2分探索木の Python での実装

二分探索木 二分探索木(にぶんたんさくぎ、英:binary search tree)は、コンピュータプログラムにおいて、「左の子孫の値 ≤ 親の値 ≤ 右の子孫の値」という制約を持つ二分木である。探索木のうちで最も基本的な木構造である。 出典: フリ...

2分探索木を Python で実装します。

以前、ほぼ全く同じ内容で記事を書いています。

2分探索木をPythonで実装します。 二分探索木 二分探索木(にぶんたんさくぎ、英:binary search tree)は、コンピュータプログラムにおいて、「左の子孫の値 ≤ 親の値 ≤ 右の子孫の値」という制約を持つ二分木である。探索木のうちで最も基本的な木...

ノード

2分探索木では、ノードは自身のデータと、0個、1個、2個のいずれかの子ノードを持っており、左側の子ノードは親の値より小さい値、右の側の子ノードは親の値より大きい値を持ちます。

# New added.
class Node(object):

    def __init__(self, data):
        self.data = data
        self.left_child = None
        self.right_child = None

データの挿入

データの挿入は、もしまだ「ルートがない」つまり最初のノードを挿入する場合は、そのノードをルートに設定します。

既にルートが設定されている場合は、データとルートを比較して小さい場合は左、大きい場合は右に進みます。

進んだ方向にノードが無ければ、データをその箇所のノードに設定します。

進んだ方向にノードがあれば、そのノードとデータの大小を比較して、また左右に分かれます。この処理は再帰的に実行します。

挿入箇所を探す際に、このようにデータの大小とノードの大小を比較して左右どちらかの方向に進んでいくことで、木が平衡していれば、\( O (\log N) \) の計算量で挿入が行えます。

class Node(object):

    def __init__(self, data):
        self.data = data
        self.left_child = None
        self.right_child = None


#New added.
class BinarySearchTree(object):

    def __init__(self):
        self.root = None

    def insert(self, data):
        if not self.root:
            self.root = Node(data)
        else:
            self._insert(data, self.root)

    # 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N)    
    def _insert(self, data, node):
        if data < node.data:
            if node.left_child:
                self._insert(data, node.left_child)
            else:
                node.left_child = Node(data)
        else:
            if node.right_child:
                self._insert(data, node.right_child)
            else:
                node.right_child = Node(data)

最小値・最大値の探索

2分探索木の中で一番左側にあるノードが最小値となります。

左側にノードがある場合は再帰的に最小値の探索を行い、左側にノードがない場合はそのノードが一番左側にあるので、最小値になります。

木が平衡していれば、\( O (\log N) \) の計算量になります。

同様に、 2分探索木の中で一番右側にあるノードが最大値となり、再帰的に処理します。

 class Node(object):

    def __init__(self, data):
        self.data = data
        self.left_child = None
        self.right_child = None


class BinarySearchTree(object):

    def __init__(self):
        self.root = None

    def insert(self, data):
        if not self.root:
            self.root = Node(data)
        else:
            self._insert(data, self.root)

    # 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N)    
    def _insert(self, data, node):
        if data < node.data:
            if node.left_child:
                self._insert(data, node.left_child)
            else:
                node.left_child = Node(data)
        else:
            if node.right_child:
                self._insert(data, node.right_child)
            else:
                node.right_child = Node(data)

    # New Added.
    def get_min(self):
        if self.root:
            return self._get_min(self.root)
        return False

    def _get_min(self, node):
        if node.left_child:
            return self._get_min(node.left_child)
        return node.data

    def get_max(self):
        if self.root:
            return self._get_max(self.root)
        return False

    def _get_max(self, node):
        if node.right_child:
            return self._get_max(node.right_child)
        return node.data

in-order traversal

in-order traversal は、「左側の部分木」→「ルート」→「右側の部分木」という順序で再帰的に木を巡ります。

class Node(object):

    def __init__(self, data):
        self.data = data
        self.left_child = None
        self.right_child = None


class BinarySearchTree(object):

    def __init__(self):
        self.root = None

    def insert(self, data):
        if not self.root:
            self.root = Node(data)
        else:
            self._insert(data, self.root)

    # 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N)    
    def _insert(self, data, node):
        if data < node.data:
            if node.left_child:
                self._insert(data, node.left_child)
            else:
                node.left_child = Node(data)
        else:
            if node.right_child:
                self._insert(data, node.right_child)
            else:
                node.right_child = Node(data)

    def get_min(self):
        if self.root:
            return self._get_min(self.root)
        return False

    def _get_min(self, node):
        if node.left_child:
            return self._get_min(node.left_child)
        return node.data

    def get_max(self):
        if self.root:
            return self._get_max(self.root)
        return False

    def _get_max(self, node):
        if node.right_child:
            return self._get_max(node.right_child)
        return node.data

    # New Added.
    def traverse_inorder(self):
        if self.root:
            self._traverse_inorder(self.root)
        else:
            print('2分木は空です。')

    def _traverse_inorder(self, node):
        if node.left_child:
            self._traverse_inorder(node.left_child)
        
        print(node.data)

        if node.right_child:
            self._traverse_inorder(node.right_child)

2分探索木を走査できるようになったので、 以下のような木になるようデータを挿入して、木の走査、最大値、最小値を求めてみます。

class Node(object):

    def __init__(self, data):
        self.data = data
        self.left_child = None
        self.right_child = None


class BinarySearchTree(object):

    def __init__(self):
        self.root = None

    def insert(self, data):
        if not self.root:
            self.root = Node(data)
        else:
            self._insert(data, self.root)

    # 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N)    
    def _insert(self, data, node):
        if data < node.data:
            if node.left_child:
                self._insert(data, node.left_child)
            else:
                node.left_child = Node(data)
        else:
            if node.right_child:
                self._insert(data, node.right_child)
            else:
                node.right_child = Node(data)

    def get_min(self):
        if self.root:
            return self._get_min(self.root)
        return False

    def _get_min(self, node):
        if node.left_child:
            return self._get_min(node.left_child)
        return node.data

    def get_max(self):
        if self.root:
            return self._get_max(self.root)
        return False

    def _get_max(self, node):
        if node.right_child:
            return self._get_max(node.right_child)
        return node.data

    def traverse_inorder(self):
        if self.root:
            self._traverse_inorder(self.root)
            # 改行対策
            print()
        else:
            print('2分木は空です。')

    def _traverse_inorder(self, node):
        if node.left_child:
            self._traverse_inorder(node.left_child)
        
        print(node.data, end=' ')

        if node.right_child:
            self._traverse_inorder(node.right_child)

    
if __name__ == '__main__':
    bst = BinarySearchTree()
    # 2分木は空です。
    bst.traverse_inorder()
    
    bst.insert(8)
    bst.insert(3)
    bst.insert(10)
    bst.insert(1)
    bst.insert(6)
    bst.insert(14)
    bst.insert(4)
    bst.insert(7)
    bst.insert(13)
    
    # ソート順 1 3 4 6 7 8 10 13 14 
    bst.traverse_inorder()
    
    # MIN 1
    print('MIN', bst.get_min())
    # MAX 14
    print('MAX', bst.get_max())

ノードの削除

場合分けが多いのと、ポインタを直接扱うことができないので、コードが少しややこしくなります。

削除を行うヘルパー関数 _remove は、親ノードの子ノードへのポインタを再設定するために、子ノードを返しています。

平衡木であれば、計算量は\( O (\log N) \) です。

class Node(object):

    def __init__(self, data):
        self.data = data
        self.left_child = None
        self.right_child = None


class BinarySearchTree(object):

    def __init__(self):
        self.root = None

    def insert(self, data):
        if not self.root:
            self.root = Node(data)
        else:
            self._insert(data, self.root)

    # 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N)    
    def _insert(self, data, node):
        if data < node.data:
            if node.left_child:
                self._insert(data, node.left_child)
            else:
                node.left_child = Node(data)
        else:
            if node.right_child:
                self._insert(data, node.right_child)
            else:
                node.right_child = Node(data)

    def get_min(self):
        if self.root:
            return self._get_min(self.root)
        return False

    def _get_min(self, node):
        if node.left_child:
            return self._get_min(node.left_child)
        return node.data

    def get_max(self):
        if self.root:
            return self._get_max(self.root)
        return False

    def _get_max(self, node):
        if node.right_child:
            return self._get_max(node.right_child)
        return node.data

    def traverse_inorder(self):
        if self.root:
            self._traverse_inorder(self.root)
            # 改行対策
            print()
        else:
            print('2分木は空です。')

    def _traverse_inorder(self, node):
        if node.left_child:
            self._traverse_inorder(node.left_child)
        
        print(node.data, end=' ')

        if node.right_child:
            self._traverse_inorder(node.right_child)
    
    # new added.
    def remove(self, data):
        if self.root:
            self.root = self._remove(data, self.root)
        return False

    def _remove(self, data, node):
        if not node:
            return node
        
        if data < node.data:
            node.left_child = self._remove(data, node.left_child)
        elif data > node.data:
            node.right_child = self._remove(data, node.right_child)
        else:
            # 削除ノードが子を持たない場合は、ノードを削除する。
            if not node.right_child and not node.left_child:
                del node
                # None を返すことで、親ノードの削除子ノードへのポインタを None に変更
                return None
            
            # 削除ノードが左の子だけを持つ場合、ノードを削除し、左の子のノードを返す
            if not node.right_child:
                temp = node.left_child
                del node
                # 左の子のノードを返すことで、親ノードの削除子ノードへのポインタを新しい子ノードに変更
                return temp
            
            # 削除ノードが右の子だけを持つ場合、左の子だけの場合の逆の操作を行う
            if not node.left_child:
                temp = node.right_child
                del node
                return temp

            # 削除ノードが左右の子を持つ場合、ここでは、左側のsubtreeの最大のノードを代わりのノードにすることにする。
            #  subtreeの最大のノードを取得するヘルパー関数
            def _get_max_node(node):
                if node.right_child:
                    return _get_max_node(node.right_child)
                return node

            temp = _get_max_node(node.left_child)
            node.data = temp.data
            # 左側のsubtreeから削除ノードと入れ替えたノードを削除
            node.left_child = self._remove(temp.data, node.left_child)
        
        return node

子のない要素の削除として、「1」と「13」を削除してみます。

if __name__ == '__main__':
    bst = BinarySearchTree()
    
    bst.insert(8)
    bst.insert(3)
    bst.insert(10)
    bst.insert(1)
    bst.insert(6)
    bst.insert(14)
    bst.insert(4)
    bst.insert(7)
    bst.insert(13)
    
    # 1 3 4 6 7 8 10 13 14 
    bst.traverse_inorder()
    # 子のない要素を削除

    bst.remove(1)
    # 3 4 6 7 8 10 13 14 
    bst.traverse_inorder()

    bst.remove(13)
    # 3 4 6 7 8 10 14
    bst.traverse_inorder()

子を一つだけ持つ要素の削除として、「10」「14」を削除してみます。

if __name__ == '__main__':
    bst = BinarySearchTree()
    
    bst.insert(8)
    bst.insert(3)
    bst.insert(10)
    bst.insert(1)
    bst.insert(6)
    bst.insert(14)
    bst.insert(4)
    bst.insert(7)
    bst.insert(13)
    
    # 1 3 4 6 7 8 10 13 14 
    bst.traverse_inorder()
    # 子を一つ持つ要素を削除

    bst.remove(10)
    # 1 3 4 6 7 8 13 14 
    bst.traverse_inorder()

    bst.remove(14)
    # 1 3 4 6 7 8 13
    bst.traverse_inorder()

子を2つ持つ要素としてルートを削除してみます。

if __name__ == '__main__':
    bst = BinarySearchTree()
    
    bst.insert(8)
    bst.insert(3)
    bst.insert(10)
    bst.insert(1)
    bst.insert(6)
    bst.insert(14)
    bst.insert(4)
    bst.insert(7)
    bst.insert(13)
    
    # 1 3 4 6 7 8 10 13 14 
    bst.traverse_inorder()
    # ルートを削除

    bst.remove(8)
    # 1 3 4 6 7 10 13 14  
    bst.traverse_inorder()

それぞれ想定通りに動作しています。

平衡二分探索木 平衡二分探索木(へいこうにぶんたんさくぎ、英:self-balancing binary search tree)とは、計算機科学において二分探索木のうち木の高さ(根からの階層の数)を自動的にできるだけ小さく維持しようとするもの(平衡木...