[Python] Educational DP Contest J – Sushi

問題

J – Sushi

期待値DP

期待値DPについて以下分かりやすいです。

確率 DP を極めよう

また、Pythonであるということ以外、メモ化再帰は以下のほぼ写経です。

Educational DP Contest の F ~ J 問題の解説と類題集 J 問題 – Sushi

回答

メモ化再帰

PyPyでもTLEで間に合わない。

import sys
# input処理を高速化する
input = sys.stdin.readline
# 許容する再帰処理の回数を変更
sys.setrecursionlimit(300*300*300+10)

def main():
    N = int(input())
    lst_a = list(map(int, input().split()))
    one, two, three = lst_a.count(1), lst_a.count(2), lst_a.count(3)

    # dp[i][j][k] := 残り1個の皿がi枚、2個の皿がj枚、3個の皿がk枚の状態から、
    # 寿司をすべてなくすのに必要な操作回数の期待値
    # dpは-1で初期化
    dp = [[[-1.0] * (N+1) for _ in range(N+1)] for _ in range(N+1)]

    def rec(i, j, k):
        if dp[i][j][k] >= 0:
            return dp[i][j][k]
        if (i == 0 and j == 0 and k == 0):
            return 0.0
        
        res = 0.0
        if i > 0:
            res += rec(i-1, j, k) * i
        if j > 0:
            res += rec(i+1, j-1, k) * j
        if k > 0:
            res += rec(i, j+1, k-1) * k
        res += N
        res *= 1.0 / (i + j + k)
        dp[i][j][k] = res
        return dp[i][j][k]


    print(rec(one, two, three))

main()

DP

メモ化再帰をDPに書き直したもの。これはPyPyであれば間に合う。

import sys

# input処理を高速化する
input = sys.stdin.readline

def main():
    N = int(input())
    lst_a = list(map(int, input().split()))
    one, two, three = lst_a.count(1), lst_a.count(2), lst_a.count(3)

    # dp[i][j][k] := 残り1個の皿がi枚、2個の皿がj枚、3個の皿がk枚の状態から、
    # 寿司をすべてなくすのに必要な操作回数の期待値
    dp = [[[0.0] * (N+1) for _ in range(N+1)] for _ in range(N+1)]

    for k in range(three + 1):
        for j in range(two + three + 1  - k):
            for i in range(one + two + three + 1 - k -j):
                if i == 0 and j == 0 and k == 0:
                    continue
                tmp = N * 1.0
                if i > 0:
                    tmp += i * dp[i -1][j][k]
                if j > 0:
                    tmp += j * dp[i + 1][j - 1][k]
                if k > 0:
                    tmp += k * dp[i][j + 1][k -1]
                dp[i][j][k] = tmp / (i + j + k)

    print(dp[one][two][three])


main()