[Python] ABC015 C メモ化再帰

問題 C - 高橋くんのバグ探し 回答 深さ優先探索で解きます。 スライドp14以後。 N, K = map(int, input().split()) T = def dfs(num_q, value=0): if n...

この問題ではメモ化再帰を使う必要はないのですが、前述の問題を、練習のためメモ化再帰を使って解きなおします。

問題

C – 高橋くんのバグ探し

回答

defaultdict

defaultdict をメモに使ってみます。

.format() 部分が読みにくいですが、メモの要素数を意識することなくメモを作成できるので、楽です。

defaultdict

import collections

N, K = map(int, input().split())
T = [list(map(int, input().split())) for _ in range(N)]

# メモ化再帰
memo = collections.defaultdict(bool)

def dfs(num_q, value=0):
    if '{}_{}'.format(num_q, value) in memo.keys():
        return memo['{}_{}'.format(num_q, value)]

    if num_q == N:
        if value == 0:
            return True
        return False

    for k in range(K):
        if dfs(num_q+1, value^T[num_q][k]):
            memo['{}_{}'.format(num_q, value)] = True
            return True

    memo['{}_{}'.format(num_q, value)] = False
    return False


if dfs(0):
    print('Found')
else:
    print('Nothing')
        

リスト

答えのメモ用と計算済みフラグ用で二つのリストを用意します。

N, K = map(int, input().split())
T = [list(map(int, input().split())) for _ in range(N)]

# メモ化再帰
memo = [[False] * 128 for _ in range(N+1)]
done = [[False] * 128 for _ in range(N+1)]

def dfs(num_q, value=0):
    if done[num_q][value]:
        return memo[num_q][value]

    if num_q == N:
        if value == 0:
            return True
        return False

    for k in range(K):
        if dfs(num_q+1, value^T[num_q][k]):
            done[num_q][value] = True
            memo[num_q][value] = True
            return True

    done[num_q][value] = True
    memo[num_q][value] = False
    return False


if dfs(0):
    print('Found')
else:
    print('Nothing')
        

lru_cache

@functools.lru_cache(user_functionを使いメモ化再帰を行います。

from functools import lru_cache

N, K = map(int, input().split())
T = [list(map(int, input().split())) for _ in range(N)]


@lru_cache(maxsize=2**20)
def dfs(num_q, value=0):
    if num_q == N:
        if value == 0:
            return True
        return False

    for k in range(K):
        if dfs(num_q+1, value^T[num_q][k]):
            return True
    return False


if dfs(0):
    print('Found')
else:
    print('Nothing')