[Python] ABC015 D メモ化再帰 100点

問題 D - 高橋くんの苦悩 回答 TLE で 0点ですが、まずは再帰的に行う全探索を考えます。 p32 です。 W = int(input()) N, K = map(int, input().split()) AB = de...

問題

D – 高橋くんの苦悩

回答

defaultdict

defaultdict を使ってみましたが、残念ながらTLEでした。

import collections

W = int(input())
N, K = map(int, input().split())
AB = [list(map(int, input().split())) for _ in range(N)]

dp = collections.defaultdict(int)

def dfs(used_w, used_n, now):
    """
    used_w 今まで選んだ合計幅
    used_n 今まで選んだ合計数
    now 今調べる番号
    """
    # メモ化再帰
    if '{}_{}_{}'.format(used_w, used_n, now) in dp.keys():
        return dp['{}_{}_{}'.format(used_w, used_n, now)]

    # 終了条件
    if now == N:
        dp['{}_{}_{}'.format(used_w, used_n, now)] = 0
        return 0
    if used_w >= W or used_n >= K:
        dp['{}_{}_{}'.format(used_w, used_n, now)] = dfs(used_w, used_n, now+1)
        return dfs(used_w, used_n, now+1)

    # 終了条件を満たさない場合
    a = AB[now][0] # 幅
    b = AB[now][1] # 重要度
    
    # 更新できる場合  
    if used_w + a <= W and used_n + 1 <= K:
        val_used = dfs(used_w + a, used_n + 1, now+1) + b
        val_not_used = dfs(used_w, used_n, now+1)
        dp['{}_{}_{}'.format(used_w, used_n, now)] = max(val_used, val_not_used)
        return max(val_used, val_not_used)    
    
    # 更新できない場合
    dp['{}_{}_{}'.format(used_w, used_n, now)] = dfs(used_w, used_n, now+1)
    return dfs(used_w, used_n, now+1)

print(dfs(0, 0, 0))

リスト

DPの配列にリストを使ってみます。

Python では TLE、PyPy では MLE でした。

W = int(input())
N, K = map(int, input().split())
AB = [list(map(int, input().split())) for _ in range(N)]

dp = [[[-1]*(N+1) for _ in range(K+1)] for _ in range(W+1)]

def dfs(used_w, used_n, now):
    """
    used_w 今まで選んだ合計幅
    used_n 今まで選んだ合計数
    now 今調べる番号
    """
    # メモ化再帰
    if dp[used_w][used_n][now] >= 0:
        return dp[used_w][used_n][now]

    # 終了条件
    if now == N:
        dp[used_w][used_n][now] = 0
        return dp[used_w][used_n][now]
    if used_w >= W or used_n >= K:
        dp[used_w][used_n][now] = dfs(used_w, used_n, now+1)
        return dp[used_w][used_n][now]

    # 終了条件を満たさない場合
    a = AB[now][0] # 幅
    b = AB[now][1] # 重要度
    
    # 更新できる場合  
    if used_w + a <= W and used_n + 1 <= K:
        val_used = dfs(used_w + a, used_n + 1, now+1) + b
        val_not_used = dfs(used_w, used_n, now+1)
        dp[used_w][used_n][now] = max(val_used, val_not_used)
        return dp[used_w][used_n][now]    
    
    # 更新できない場合
    dp[used_w][used_n][now] = dfs(used_w, used_n, now+1)
    return dp[used_w][used_n][now]

print(dfs(0, 0, 0))

lru_cache

@functools.lru_cache(user_functionでメモ化再帰を行ってみますが、残念ながらTLE。

from functools import lru_cache

W = int(input())
N, K = map(int, input().split())
AB = [list(map(int, input().split())) for _ in range(N)]


@lru_cache(maxsize=2**30)
def dfs(used_w, used_n, now):
    """
    used_w 今まで選んだ合計幅
    used_n 今まで選んだ合計数
    now 今調べる番号
    """
    # 終了条件
    if now == N:
        return 0
    if used_w >= W or used_n >= K:
        return dfs(used_w, used_n, now+1)

    # 終了条件を満たさない場合
    a = AB[now][0] # 幅
    b = AB[now][1] # 重要度
    # 更新できる場合  
    if used_w + a <= W and used_n + 1 <= K:
        val_used = dfs(used_w + a, used_n + 1, now+1) + b
        val_not_used = dfs(used_w, used_n, now+1)
        return max(val_used, val_not_used)    
    # 更新できない場合
    return dfs(used_w, used_n, now+1)

print(dfs(0, 0, 0))

DPにリストを使って PyPy で回したのがMLEで一番惜しかったので、これを array を使うことでメモリを節約して PyPy で回してみると、AC でした。

import array

W = int(input())
N, K = map(int, input().split())
AB = [list(map(int, input().split())) for _ in range(N)]

dp = [[array.array('i', [-1]*(N+1)) for _ in range(K+1)] for _ in range(W+1)]

def dfs(used_w, used_n, now):
    """
    used_w 今まで選んだ合計幅
    used_n 今まで選んだ合計数
    now 今調べる番号
    """
    # メモ化再帰
    if dp[used_w][used_n][now] >= 0:
        return dp[used_w][used_n][now]

    # 終了条件
    if now == N:
        dp[used_w][used_n][now] = 0
        return dp[used_w][used_n][now]
    if used_w >= W or used_n >= K:
        dp[used_w][used_n][now] = dfs(used_w, used_n, now+1)
        return dp[used_w][used_n][now]

    # 終了条件を満たさない場合
    a = AB[now][0] # 幅
    b = AB[now][1] # 重要度
    
    # 更新できる場合  
    if used_w + a <= W and used_n + 1 <= K:
        val_used = dfs(used_w + a, used_n + 1, now+1) + b
        val_not_used = dfs(used_w, used_n, now+1)
        dp[used_w][used_n][now] = max(val_used, val_not_used)
        return dp[used_w][used_n][now]    
    
    # 更新できない場合
    dp[used_w][used_n][now] = dfs(used_w, used_n, now+1)
    return dp[used_w][used_n][now]

print(dfs(0, 0, 0))