[データ構造] 三分探索木をPythonで実装

三分探索木 三分探索木(さんぶんたんさくぎ、英:ternary search tree)は、トライ木の各ノードを二分探索木として表現したデータ構造である。 出典: フリー百科事典『ウィキペディア(Wikipedia)』 三分探索木は、トライ木...

Python で三分探索木を実装してみます。

ノードクラス

それぞれのノードは、その文字を持ち、左右と真ん中に子を持ちます。

また、キーとしてその文字列が存在する場合は、値を持ちます。

class Node(object):

    def __init__(self, char):
        self.char = char
        self.left_child = None
        self.middle_child = None
        self.right_child = None
        self.value = None

TernarySearchTree クラス

挿入

値の挿入を考えます。

ヘルパー関数_addを利用して再帰的に処理しますが、場合分けがややこしいです。

_addの引数indexは、現在文字列の何文字目を見ているのかを管理する変数です。

左右の子に進んだ場合は、再度その文字について考えるので、indexは増えません。

また、真ん中に進む場合はその文字が見つかったということで、さらにindexが文字列の最後まで到達していなければ、次の文字を検討するためindex1増やします。

indexが最後まで到達している場合は文字列の最後まで到達した場合になります。

class Node(object):

    def __init__(self, char):
        self.char = char
        self.left_child = None
        self.middle_child = None
        self.right_child = None
        self.value = None


class TernarySearchTree(object):

    def __init__(self):
        self.root = None

    def add(self, key, value):
        self.root = self._add(self.root, key, value, 0)

    def _add(self, node, key, value, index):
        '''
        index: 文字列の何文字目かを管理
        '''
        current_char = key[index]
        # ノードが存在しないときはノードを生成
        if node == None:
            node = Node(current_char)

        # 文字が前の時は左の子に再帰的に関数を適用。
        # 同じ文字について検討するので index は増えない。
        if current_char < node.char:
            node.left_child = self._add(node.left_child, key, value, index)
        # 文字が後の時は右の子に再帰的に関数を適用。
        # 同じ文字について検討するので index は増えない。
        elif current_char > node.char:
            node.right_child = self._add(node.right_child, key, value, index)
        # 文字が前でも後でもない、つまりその文字である
        # インデックスがまだ最後まで到着していない
        # 真ん中の子に再帰的に関数を適用
        elif index < len(key) -1:
            node.middle_child = self._add(node.middle_child, key, value, index+1)
        # 文字が前でも後でもない、つまりその文字である
        # インデックスが最後まで到着している、つまりその文字列
        else:
            node.value = value

        return Node

取得

値の取得も、挿入と同様にヘルパー関数を使い再帰的に行います。

場合分けも挿入と同様です。

class Node(object):

    def __init__(self, char):
        self.char = char
        self.left_child = None
        self.middle_child = None
        self.right_child = None
        self.value = None


class TernarySearchTree(object):

    def __init__(self):
        self.root = None

    def add(self, key, value):
        self.root = self._add(self.root, key, value, 0)

    def _add(self, node, key, value, index):
        '''
        index: 文字列の何文字目かを管理
        '''
        current_char = key[index]
        # ノードが存在しないときはノードを生成
        if node == None:
            node = Node(current_char)

        # 文字が前の時は左の子に再帰的に関数を適用。
        # 同じ文字について検討するので index は増えない。
        if current_char < node.char:
            node.left_child = self._add(node.left_child, key, value, index)
        # 文字が後の時は右の子に再帰的に関数を適用。
        # 同じ文字について検討するので index は増えない。
        elif current_char > node.char:
            node.right_child = self._add(node.right_child, key, value, index)
        # 文字が前でも後でもない、つまりその文字である
        # インデックスがまだ最後まで到着していない
        # 真ん中の子に再帰的に関数を適用
        elif index < len(key) -1:
            node.middle_child = self._add(node.middle_child, key, value, index+1)
        # 文字が前でも後でもない、つまりその文字である
        # インデックスが最後まで到着している、つまりその文字列
        else:
            node.value = value

        return node

    def look_up(self, key):
        node = self._look_up(self.root, key, 0)

        if node is None:
            return None
        return node.value

    def _look_up(self, node, key, index):
        if node is None:
            return None

        current_char = key[index]

        if current_char < node.char:
            return self._look_up(node.left_child, key, index)
        elif current_char > node.char:
            return self._look_up(node.right_child, key, index)
        elif index < len(key) - 1:
            return self._look_up(node.middle_child, key, index+1)
        else:
            return node

テスト

以下のような木を想定して、動作するか簡単にテストします。

# "cute","cup","at","as","he","us" and "i" を示す。
          c
        / | \
       a  u  h
       |  |  | \
       t  t  e  u
     /  / |   / |
    s  p  e  i  s

class Node(object):

    def __init__(self, char):
        self.char = char
        self.left_child = None
        self.middle_child = None
        self.right_child = None
        self.value = None


class TernarySearchTree(object):

    def __init__(self):
        self.root = None

    def add(self, key, value):
        self.root = self._add(self.root, key, value, 0)

    def _add(self, node, key, value, index):
        '''
        index: 文字列の何文字目かを管理
        '''
        current_char = key[index]
        # ノードが存在しないときはノードを生成
        if node == None:
            node = Node(current_char)

        # 文字が前の時は左の子に再帰的に関数を適用。
        # 同じ文字について検討するので index は増えない。
        if current_char < node.char:
            node.left_child = self._add(node.left_child, key, value, index)
        # 文字が後の時は右の子に再帰的に関数を適用。
        # 同じ文字について検討するので index は増えない。
        elif current_char > node.char:
            node.right_child = self._add(node.right_child, key, value, index)
        # 文字が前でも後でもない、つまりその文字である
        # インデックスがまだ最後まで到着していない
        # 真ん中の子に再帰的に関数を適用
        elif index < len(key) -1:
            node.middle_child = self._add(node.middle_child, key, value, index+1)
        # 文字が前でも後でもない、つまりその文字である
        # インデックスが最後まで到着している、つまりその文字列
        else:
            node.value = value

        return node

    def look_up(self, key):
        node = self._look_up(self.root, key, 0)

        if node is None:
            return None
        return node.value

    def _look_up(self, node, key, index):
        if node is None:
            return None

        current_char = key[index]

        if current_char < node.char:
            return self._look_up(node.left_child, key, index)
        elif current_char > node.char:
            return self._look_up(node.right_child, key, index)
        elif index < len(key) - 1:
            return self._look_up(node.middle_child, key, index+1)
        else:
            return node


if __name__ == '__main__':
    tst = TernarySearchTree()

    tst.add('cute', 1)
    tst.add('cup', 2)
    tst.add('at', 3)
    tst.add('as', 4)
    tst.add('he', 5)
    tst.add('us', 6)
    tst.add('i', 7)

    # 1
    print(tst.look_up('cute'))
    # 2
    print(tst.look_up('cup'))
    # 3
    print(tst.look_up('at'))
    # 5
    print(tst.look_up('he'))
    # 7
    print(tst.look_up('i'))

    # None
    print(tst.look_up('a'))
    print(tst.look_up('cg'))