[Python] 2分探索木が同一か判定

2つの2分探索木が同一かどうか判定します。

同じ場所のノードに同じデータを持つ場合を同一と判断します。

2分探索木は、以下の実装を使います。

2分探索木を Python で実装します。 以前、ほぼ全く同じ内容で記事を書いています。 ノード 2分探索木では、ノードは自身のデータと、0個、1個、2個のいずれかの子ノードを持っており、左側の子ノードは親の値より小さい値、右...

同じかどうかの判定を左の子と右の子に再帰的に行うことで判断できます。

判定するノードが葉ノードの子である場合(つまりノードが None の場合)がベースケースになり、この時は双方が None である必要があるのでノード自体を比較します。

それ以外は、ノードの持つデータを比較します。

def compare_bst(bst1, bst2):
    return _compare_bst(bst1.root, bst2.root)

def _compare_bst(node1, node2):
    # base case
    if (node1 is None) or (node2 is None):
        return node1 == node2
    
    if node1.data != node2.data:
        return False

    # AND を使うので、左右ともに同じなければ False
    return _compare_bst(node1.left_child, node2.left_child) \
        and _compare_bst(node1.right_child, node2.right_child)

テストします。

class Node(object):

    def __init__(self, data):
        self.data = data
        self.left_child = None
        self.right_child = None


class BinarySearchTree(object):

    def __init__(self):
        self.root = None

    def insert(self, data):
        if not self.root:
            self.root = Node(data)
        else:
            self._insert(data, self.root)

    # 木が平衡していればO(log(N)) 一直線に伸びているような場合はO(N)    
    def _insert(self, data, node):
        if data < node.data:
            if node.left_child:
                self._insert(data, node.left_child)
            else:
                node.left_child = Node(data)
        else:
            if node.right_child:
                self._insert(data, node.right_child)
            else:
                node.right_child = Node(data)

    def traverse_inorder(self):
        if self.root:
            self._traverse_inorder(self.root)
            # 改行対策
            print()
        else:
            print('2分木は空です。')

    def _traverse_inorder(self, node):
        if node.left_child:
            self._traverse_inorder(node.left_child)
        
        print(node.data, end=' ')

        if node.right_child:
            self._traverse_inorder(node.right_child)


def compare_bst(bst1, bst2):
    return _compare_bst(bst1.root, bst2.root)

def _compare_bst(node1, node2):
    # base case
    if (node1 is None) or (node2 is None):
        return node1 == node2
    
    if node1.data != node2.data:
        return False

    # AND を使うので、左右ともに同じなければ False
    return _compare_bst(node1.left_child, node2.left_child) \
        and _compare_bst(node1.right_child, node2.right_child)


if __name__ == '__main__':
    bst1 = BinarySearchTree()
    bst1.insert(8)
    bst1.insert(3)
    bst1.insert(10)
    bst1.insert(1)
    bst1.insert(6)
    bst1.insert(14)
    bst1.insert(4)
    bst1.insert(7)
    bst1.insert(13)

    bst2 = BinarySearchTree()
    bst2.insert(8)
    bst2.insert(3)
    bst2.insert(10)
    bst2.insert(1)
    bst2.insert(6)
    bst2.insert(14)
    bst2.insert(4)
    bst2.insert(7)
    bst2.insert(13)

    bst3 = BinarySearchTree()
    bst3.insert(14)
    bst3.insert(13)
    bst3.insert(10)
    bst3.insert(8)
    bst3.insert(7)
    bst3.insert(6)
    bst3.insert(4)
    bst3.insert(3)
    bst3.insert(1)
    
    # True bst1 bst2 は同一
    print(compare_bst(bst1, bst2))
    # False bst3 は一方向にのみ伸びた不均衡な木
    print(compare_bst(bst1, bst3))