Python でAVL木を実装します。
ノード
2分探索木とあまり変わりません。
height というインスタンス変数を設定し、ノードの高さを保持します。
class Node(object):
def __init__(self, data):
self.data = data
self.left_child = None
self.right_child = None
self.height = 0
AVL
平衡係数の取得
まずは、ノードの高さを取得する関数 get_height を定義します。
class Node(object):
def __init__(self, data):
self.data = data
self.left_child = None
self.right_child = None
self.height = 0
class AVL(object):
def __init__(self):
self.root = None
def get_height(self, node):
if not node:
return -1
return node.height
def get_balance(self, node):
'''
返り値が1より大きい場合、左の部分木が重い → 右回転
返り値が-1より小さい場合、右の部分木が重い → 左回転
'''
if not node:
return 0
return self.get_height(node.left_child) - self.get_height(node.right_child)
get_heightを使い、左右の部分木の高さの差 (=平衡係数)を取得する関数get_balanceを設定します。
この関数の返り値が1より大きい(平衡係数> 1 )場合、左の部分木が重いということで、木の平衡を保つため、最終的に右回転を行います。逆に返り値が-1より小さい (平衡係数> ‐1 ) 場合、右の部分木が重いので、最終的に左回転を行います。
回転
右回転を考えます。

まずは、tempに対象ノードを入れて、子ノードのポインタ付け替えを行います。
元々のルートと左側の子は、それぞれ子ノードが変わり高さが変わっているので、高さの再計算を行います。
最後にルートを再設定するため、新しいルートを返します。
左回転は、これと逆の操作になります。
class Node(object):
def __init__(self, data):
self.data = data
self.left_child = None
self.right_child = None
self.height = 0
class AVL(object):
def __init__(self):
self.root = None
def get_height(self, node):
if not node:
return -1
return node.height
def get_balance(self, node):
'''
返り値が1より大きい場合、左の部分木が重い → 右回転
返り値が-1より小さい場合、右の部分木が重い → 左回転
'''
if not node:
return 0
return self.get_height(node.left_child) - self.get_height(node.right_child)
def rotate_right(self, node):
temp_left_child = node.left_child
temp_left_right_child = temp_left_child.right_child
temp_left_child.right_child = node
node.left_child = temp_left_right_child
node.height = max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1
temp_left_child.height = max(self.get_height(temp_left_child.left_child), self.get_height(temp_left_child.right_child)) + 1
return temp_left_child
def rotate_left(self, node):
temp_right_child = node.right_child
temp_right_left_child = temp_right_child.left_child
temp_right_child.left_child = node
node.right_child = temp_right_left_child
node.height = max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1
temp_right_child.height = max(self.get_height(temp_right_child.left_child), self.get_height(temp_right_child.right_child)) + 1
return temp_right_child
挿入
データの挿入を行います。
2分探索と同様に再帰的に処理し、ポインタの付け替えは、_insert の返り値を用いて行います。
ノードの高さの更新を行い、それに応じて木の平衡を保ちます。
class Node(object):
def __init__(self, data):
self.data = data
self.left_child = None
self.right_child = None
self.height = 0
class AVL(object):
def __init__(self):
self.root = None
def get_height(self, node):
if not node:
return -1
return node.height
def get_balance(self, node):
'''
返り値が1より大きい場合、左の部分木が重い → 右回転
返り値が-1より小さい場合、右の部分木が重い → 左回転
'''
if not node:
return 0
return self.get_height(node.left_child) - self.get_height(node.right_child)
def rotate_right(self, node):
temp_left_child = node.left_child
temp_left_right_child = temp_left_child.right_child
temp_left_child.right_child = node
node.left_child = temp_left_right_child
node.height = max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1
temp_left_child.height = max(self.get_height(temp_left_child.left_child), self.get_height(temp_left_child.right_child)) + 1
return temp_left_child
def rotate_left(self, node):
temp_right_child = node.right_child
temp_right_left_child = temp_right_child.left_child
temp_right_child.left_child = node
node.right_child = temp_right_left_child
node.height = max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1
temp_right_child.height = max(self.get_height(temp_right_child.left_child), self.get_height(temp_right_child.right_child)) + 1
return temp_right_child
def insert(self, data):
print(f'--------------')
print(f'start inserting {data}')
self.root = self._insert(data, self.root)
print(f'finish inserting {data}')
print(f'--------------')
def _insert(self, data, node):
if not node:
return Node(data)
if data < node.data:
node.left_child = self._insert(data, node.left_child)
else:
node.right_child = self._insert(data, node.right_child)
node.height = max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1
return self.settle_unbalance(data, node)
def settle_unbalance(self, data, node):
balance = self.get_balance(node)
# balace >1 -> 左の方が重い
# data < node.left_child.data 左の子より左側にデータが挿入されたので、左の子の左の子の方が重い
# つまり、left-left heavy 右回転を行う
if balance > 1 and data < node.left_child.data:
print('left-left heavy')
return self.rotate_right(node)
# balace > -1 -> 右の方が重い
# data > node.right_child.data 右の子より右側にデータが挿入されたので、右の子の右の子の方が重い
# つまり、right-right heavy 左回転を行う
if balance < -1 and data > node.right_child.data:
print('right-right heavy')
return self.rotate_left(node)
# balace >1 -> 左の方が重い
# data > node.left_child.data 左の子より右側にデータが挿入されたので、左の子の右の子の方が重い
# つまり、left-right heavy 左回転を行い右回転を行う
if balance > 1 and data > node.left_child.data:
print('left-right heavy')
node.left_child = self.rotate_left(node.left_child)
return self.rotate_right(node)
# balace > -1 -> 右の方が重い
# data < node.right_child.data 右の子より左側にデータが挿入されたので、右の子の左の子の方が重い
# つまり、right-left heavy 右回転を行い左回転を行う
if balance < -1 and data < node.right_child.data:
node.right_child = self.rotate_right(node.right_child)
print('right-left heavy')
return self.rotate_left(node)
return node
2分探索木と同様の in-order traversal を使い、挿入が正しく行われているか確認をします。
class Node(object):
def __init__(self, data):
self.data = data
self.left_child = None
self.right_child = None
self.height = 0
class AVL(object):
def __init__(self):
self.root = None
def get_height(self, node):
if not node:
return -1
return node.height
def get_balance(self, node):
'''
返り値が1より大きい場合、左の部分木が重い → 右回転
返り値が-1より小さい場合、右の部分木が重い → 左回転
'''
if not node:
return 0
return self.get_height(node.left_child) - self.get_height(node.right_child)
def rotate_right(self, node):
temp_left_child = node.left_child
temp_left_right_child = temp_left_child.right_child
temp_left_child.right_child = node
node.left_child = temp_left_right_child
node.height = max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1
temp_left_child.height = max(self.get_height(temp_left_child.left_child), self.get_height(temp_left_child.right_child)) + 1
return temp_left_child
def rotate_left(self, node):
temp_right_child = node.right_child
temp_right_left_child = temp_right_child.left_child
temp_right_child.left_child = node
node.right_child = temp_right_left_child
node.height = max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1
temp_right_child.height = max(self.get_height(temp_right_child.left_child), self.get_height(temp_right_child.right_child)) + 1
return temp_right_child
def insert(self, data):
print(f'--------------')
print(f'start inserting {data}')
self.root = self._insert(data, self.root)
print(f'finish inserting {data}')
print(f'--------------')
def _insert(self, data, node):
if not node:
return Node(data)
if data < node.data:
node.left_child = self._insert(data, node.left_child)
else:
node.right_child = self._insert(data, node.right_child)
node.height = max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1
return self.settle_unbalance(data, node)
def settle_unbalance(self, data, node):
balance = self.get_balance(node)
# balace >1 -> 左の方が重い
# data < node.left_child.data 左の子より左側にデータが挿入されたので、左の子の左の子の方が重い
# つまり、left-left heavy 右回転を行う
if balance > 1 and data < node.left_child.data:
print('left-left heavy')
return self.rotate_right(node)
# balace > -1 -> 右の方が重い
# data > node.right_child.data 右の子より右側にデータが挿入されたので、右の子の右の子の方が重い
# つまり、right-right heavy 左回転を行う
if balance < -1 and data > node.right_child.data:
print('right-right heavy')
return self.rotate_left(node)
# balace >1 -> 左の方が重い
# data > node.left_child.data 左の子より右側にデータが挿入されたので、左の子の右の子の方が重い
# つまり、left-right heavy 左回転を行い右回転を行う
if balance > 1 and data > node.left_child.data:
print('left-right heavy')
node.left_child = self.rotate_left(node.left_child)
return self.rotate_right(node)
# balace > -1 -> 右の方が重い
# data < node.right_child.data 右の子より左側にデータが挿入されたので、右の子の左の子の方が重い
# つまり、right-left heavy 右回転を行い左回転を行う
if balance < -1 and data < node.right_child.data:
node.right_child = self.rotate_right(node.right_child)
print('right-left heavy')
return self.rotate_left(node)
return node
def traverse_inorder(self):
if self.root:
self._traverse_inorder(self.root)
# 改行対策
print()
else:
print('木は空です。')
def _traverse_inorder(self, node):
if node.left_child:
self._traverse_inorder(node.left_child)
print(node.data, end=' ')
if node.right_child:
self._traverse_inorder(node.right_child)
if __name__ == '__main__':
avl = AVL()
avl.insert(2)
avl.insert(4)
# right-right heavy
avl.insert(8)
avl.insert(16)
# right-right heavy
avl.insert(32)
# right-left heavy
avl.insert(14)
# 2 4 8 14 16 32
avl.traverse_inorder()
# height 2
print('height', avl.get_height(avl.root))
avl = AVL()
avl.insert(32)
avl.insert(16)
# left-left heavy
avl.insert(8)
avl.insert(4)
# left-left heavy
avl.insert(2)
# left-right heavy
avl.insert(14)
# 2 4 8 14 16 32
avl.traverse_inorder()
# height 2
print('height', avl.get_height(avl.root))
削除
2分探索木と同様に再帰的にノード削除を行い、その後それぞれのノードの平衡を確認します。
class Node(object):
def __init__(self, data):
self.data = data
self.left_child = None
self.right_child = None
self.height = 0
class AVL(object):
def __init__(self):
self.root = None
def get_height(self, node):
if not node:
return -1
return node.height
def get_balance(self, node):
'''
返り値が1より大きい場合、左の部分木が重い → 右回転
返り値が-1より小さい場合、右の部分木が重い → 左回転
'''
if not node:
return 0
return self.get_height(node.left_child) - self.get_height(node.right_child)
def rotate_right(self, node):
temp_left_child = node.left_child
temp_left_right_child = temp_left_child.right_child
temp_left_child.right_child = node
node.left_child = temp_left_right_child
node.height = max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1
temp_left_child.height = max(self.get_height(temp_left_child.left_child), self.get_height(temp_left_child.right_child)) + 1
return temp_left_child
def rotate_left(self, node):
temp_right_child = node.right_child
temp_right_left_child = temp_right_child.left_child
temp_right_child.left_child = node
node.right_child = temp_right_left_child
node.height = max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1
temp_right_child.height = max(self.get_height(temp_right_child.left_child), self.get_height(temp_right_child.right_child)) + 1
return temp_right_child
def insert(self, data):
self.root = self._insert(data, self.root)
def _insert(self, data, node):
if not node:
return Node(data)
if data < node.data:
node.left_child = self._insert(data, node.left_child)
else:
node.right_child = self._insert(data, node.right_child)
node.height = max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1
return self.settle_unbalance(data, node)
def settle_unbalance(self, data, node):
balance = self.get_balance(node)
# balace >1 -> 左の方が重い
# data < node.left_child.data 左の子より左側にデータが挿入されたので、左の子の左の子の方が重い
# つまり、left-left heavy 右回転を行う
if balance > 1 and data < node.left_child.data:
return self.rotate_right(node)
# balace > -1 -> 右の方が重い
# data > node.right_child.data 右の子より右側にデータが挿入されたので、右の子の右の子の方が重い
# つまり、right-right heavy 左回転を行う
if balance < -1 and data > node.right_child.data:
return self.rotate_left(node)
# balace >1 -> 左の方が重い
# data > node.left_child.data 左の子より右側にデータが挿入されたので、左の子の右の子の方が重い
# つまり、left-right heavy 左回転を行い右回転を行う
if balance > 1 and data > node.left_child.data:
node.left_child = self.rotate_left(node.left_child)
return self.rotate_right(node)
# balace > -1 -> 右の方が重い
# data < node.right_child.data 右の子より左側にデータが挿入されたので、右の子の左の子の方が重い
# つまり、right-left heavy 右回転を行い左回転を行う
if balance < -1 and data < node.right_child.data:
node.right_child = self.rotate_right(node.right_child)
return self.rotate_left(node)
return node
def traverse_inorder(self):
if self.root:
self._traverse_inorder(self.root)
# 改行対策
print()
else:
print('木は空です。')
def _traverse_inorder(self, node):
if node.left_child:
self._traverse_inorder(node.left_child)
print(node.data, end=' ')
if node.right_child:
self._traverse_inorder(node.right_child)
def remove(self, data):
if self.root:
self.root = self._remove(data, self.root)
def _remove(self, data, node):
if data < node.data:
node.left_child = self._remove(data, node.left_child)
elif data > node.data:
node.right_child = self._remove(data, node.right_child)
else:
# 削除ノードが子を持たない場合は、ノードを削除する。
if not node.right_child and not node.left_child:
del node
# None を返すことで、親ノードの削除子ノードへのポインタを None に変更
return None
# 削除ノードが左の子だけを持つ場合、ノードを削除し、左の子のノードを返す
if not node.right_child:
temp = node.left_child
del node
# 左の子のノードを返すことで、親ノードの削除子ノードへのポインタを新しい子ノードに変更
return temp
# 削除ノードが右の子だけを持つ場合、左の子だけの場合の逆の操作を行う
if not node.left_child:
temp = node.right_child
del node
return temp
# 削除ノードが左右の子を持つ場合、ここでは、左側のsubtreeの最大のノードを代わりのノードにすることにする。
# subtreeの最大のノードを取得するヘルパー関数
def _get_max_node(node):
if node.right_child:
return _get_max_node(node.right_child)
return node
temp = _get_max_node(node.left_child)
node.data = temp.data
# 左側のsubtreeから削除ノードと入れ替えたノードを削除
node.left_child = self._remove(temp.data, node.left_child)
# 単体の木の場合
if not node:
return node
node.height = max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1
balance = self.get_balance(node)
if balance > 1 and self.get_balance(node.left_child) >= 0:
print('left-left heavy')
return self.rotate_right(node)
if balance < -1 and self.get_balance(node.left_child) <= 0:
print('right-right heavy')
return self.rotate_left(node)
if balance > 1 and self.get_balance(node.left_child) < 0:
print('left-right heavy')
node.left_child = self.rotate_left(node.left_child)
return self.rotate_right(node)
if balance < -1 and self.get_balance(node.left_child) > 0:
node.right_child = self.rotate_right(node.right_child)
print('right-left heavy')
return self.rotate_left(node)
return node
if __name__ == '__main__':
avl = AVL()
avl.insert(2)
avl.insert(4)
avl.insert(8)
avl.insert(16)
avl.insert(32)
avl.insert(64)
avl.insert(128)
avl.insert(256)
avl.insert(512)
# 2 4 8 16 32 64 128 256 512
avl.traverse_inorder()
avl.remove(2)
# right-right heavy
avl.remove(4)
avl.remove(512)
avl.remove(256)
# left-left heavy
avl.remove(128)
avl.remove(64)
# 8 16 32
avl.traverse_inorder()
# height 1
print('height', avl.get_height(avl.root))