[Python] 連結リストで中心のノードを探す

Python で、連結リストの中心にあるノードを探します。

連結リストは、以下の実装を使います。

連結リストを Python で実装します。 ノードのクラス 連結リストのそれぞれのノードは、自身のデータと、次のノードを指すリンクを持ちます。 class Node(object): def __init__(self, data...

単純に考えると、先頭からポインタをたどり最後まで一度移動して連結リストの長さを取得し、その後その半分の位置まで進めば、中心に辿り着きます。

def get_middle_node(self):
    current_node = self.head
    size = 0
    while current_node is not None:
        size += 1
        current_node = current_node.next_node

    middle = size // 2
    current_node = self.head        
    for _ in range(middle):
        current_node = current_node.next_node

    return current_node.data

連結リストに要素数の属性を持たせノードの追加や削除のたびに増減させることで、要素数に\( O(1) \)でアクセスできるようになるので、ループが1回だけになりより速く実行できます。

def get_middle_node_use_length(self):
    middle = self.length // 2
    current_node = self.head

    for _ in range(middle):
        current_node = current_node.next_node

    return current_node.data

ポインタを2つ使い、片方のポインタを倍の速さで進めると、速いポインタが最後に到達した時点で遅いポインタが中心にある、と考えることでも、ループを一回だけにすることができます。

def get_middle_node_use_two_pointers(self):
    slow_pointer = self.head
    fast_pointer = self.head
    
    while fast_pointer.next_node \
        and fast_pointer.next_node.next_node:
        fast_pointer = fast_pointer.next_node.next_node
        slow_pointer = slow_pointer.next_node

    return slow_pointer.data

今回のコードをまとめます。

class Node(object):

    def __init__(self, data):
        self.data = data
        self.next_node = None


class LinkedList(object):
    
    def __init__(self):
        self.head = None
        self.length = 0
    
    # O(1)
    def insert_first(self, data):

        new_node = Node(data)
        self.length += 1

        if not self.head:
            self.head = new_node
        else:
            new_node.next_node = self.head
            self.head = new_node

    # O(N)
    def __str__(self):
        elems = ''
        current_node = self.head
        while current_node:
            if current_node.next_node is None:
                elems += str(current_node.data)
            else:
                elems += str(current_node.data) + ', '
            current_node = current_node.next_node
        return '[' + elems +']'

    # O(1)
    def __len__(self):
        return self.length
    
    # O(N)
    def insert_last(self, data):

        if self.head is None:
            self.insert_first(data)
        else:
            new_node = Node(data)
            self.length += 1
            current_node = self.head

            while current_node.next_node is not None:
                current_node = current_node.next_node
            
            current_node.next_node = new_node

    # 追加
    # O(N)
    def remove(self, data):

        if self.head is None:
            print('リストは空です。')
            return

        current_node = self.head
        previous_node = None

        while current_node.data != data:
            if current_node.next_node is None:
                print('削除に該当するデータがありません。')
                return 
            previous_node = current_node
            current_node = current_node.next_node
        
        # 削除ノードが先頭
        if previous_node is None:
            self.head = current_node.next_node
            self.length -= 1
        else:
            previous_node.next_node = current_node.next_node
            self.length -= 1

    def get_middle_node(self):
        current_node = self.head
        size = 0
        while current_node is not None:
            size += 1
            current_node = current_node.next_node

        middle = size // 2
        current_node = self.head        
        for _ in range(middle):
            current_node = current_node.next_node

        return current_node.data
        
    def get_middle_node_use_length(self):
        middle = self.length // 2
        current_node = self.head

        for _ in range(middle):
            current_node = current_node.next_node

        return current_node.data

    def get_middle_node_use_two_pointers(self):
        slow_pointer = self.head
        fast_pointer = self.head
        
        while fast_pointer.next_node \
            and fast_pointer.next_node.next_node:
            fast_pointer = fast_pointer.next_node.next_node
            slow_pointer = slow_pointer.next_node

        return slow_pointer.data


if __name__ == '__main__':
    linked_list = LinkedList()

    linked_list.insert_first(1)
    linked_list.insert_first(2)
    linked_list.insert_first(3)
    linked_list.insert_first(4)
    linked_list.insert_first(5)
    linked_list.insert_first(6)
    linked_list.insert_first(7)   
    # [7, 6, 5, 4, 3, 2, 1] 
    print(linked_list)
    # 4
    print(linked_list.get_middle_node())
    # 4
    print(linked_list.get_middle_node_use_length())
    # 4
    print(linked_list.get_middle_node_use_two_pointers())