[Python] ABC007 D 桁DP

ABC007 Dを桁DPを使って解きます。

桁DP/Digit DP

「n以下の整数の処理」を考えるときに、 大きい桁から一桁ずつ数を見ていき、結果を代入するDP配列に、nより小さいことが確定しているかどうかのフラグを含めることで状態を管理する動的計画法です。

桁DPの解説は、以下が分かりやすいです。

桁DPの痒いところに手が届く解説

Digit DP 入門

桁 DP の思想 〜 K 以下の整数を走査するとはどういうことか 〜

簡単な問題から考える

まずは、0から始まる5桁の整数がいくつ存在するのかを、一番上の桁から数えるコードを考えます。

0から99999まであるので、当然100000個存在します。

# 桁数
length = 5

dp = [0 for _ in range(length+1)]
dp[0] = 1

for i in range(length):
    for d in range(10):
        dp[i+1] += dp[i]

# [1, 10, 100, 1000, 10000, 100000]
print(dp)

次に、Digit DP 入門に従って上のコードに「未満フラグ」を導入することで 、以下の簡単な問題を解いてみます。

「\(N = 12345 \) 以下の0と自然数の数を求める」

N = '12345'

# 桁数
length = len(N)

# dp[ 決めた桁数 ][ 未満フラグ ] := 総数
dp = [[0]*2 for _ in range(length+1)]
dp[0][0] = 1
print(dp)

for i in range(length):
    max_digit = int(N[i])
    
    # flag_less: 過去に対応する桁より小さい数を入れたことがある時はTrue=1
    # Trueの時、0から9の数字を使える ex. 12345 の2桁目まで 11*** であれば、
    # 3桁目はどの数字でも良い。
    # Falseの時、その桁の数字が使える最大の数字 ex. 12345 2の2桁目まで 12***であれば、
    # 3桁目は 0 or 1 or 2 or 3 のいずれか。
    for flag_less in range(2):
        range_digit = 9 if flag_less else max_digit
        for d in range(range_digit+1):
            # d が決まると、遷移先の未満フラグも決まる。
            flag_less_next = 0        
            if flag_less==1 or d < max_digit:
                flag_less_next = 1
            dp[i+1][flag_less_next] += dp[i][flag_less]

# [[1, 0], [1, 1], [1, 12], [1, 123], [1, 1234], [1, 12345]]
print(dp)
# 12345以下の0と自然数の数: 12346
print(f'{N}以下の0と自然数の数: {dp[length][0]+dp[length][1]}')

上のコードを利用して以下の問題を考えます。

「\(N = 12345 \) 以下の0と自然数で各桁に5を含まないものの数を求める」

N = '12345'

# 桁数
length = len(N)

# dp[ 決めた桁数 ][ 未満フラグ ] := 総数
dp = [[0]*2 for _ in range(length+1)]
dp[0][0] = 1

for i in range(length):
    max_digit = int(N[i])
    
    for flag_less in range(2):
        range_digit = 9 if flag_less else max_digit
        for d in range(range_digit+1):
            # 5 の時はスキップ
            if d == 5:
                continue
            # d が決まると、遷移先の未満フラグも決まる。
            flag_less_next = 0        
            if flag_less==1 or d < max_digit:
                flag_less_next = 1
            dp[i+1][flag_less_next] += dp[i][flag_less]

# [[1, 0], [1, 1], [1, 11], [1, 102], [1, 922], [0, 8303]]
print(dp)
# 12345以下の0と自然数の数で各桁に5を含まないものの個数: 8303
print(f'{N}以下の0と自然数の数で各桁に5を含まないものの個数: {dp[length][0]+dp[length][1]}')

「\(N = 12345 \) 以下の0と自然数でいずれかの桁に4か9を含むものの数を求める」

N = '12345'

# 桁数
length = len(N)

# dp[ 決めた桁数 ][ 未満フラグ ][4 or 9] := 総数
dp = [[[0]*2 for _ in range(2)] for _ in range(length+1)]
dp[0][0][0] = 1

for i in range(length):
    max_digit = int(N[i])
    for flag_less in range(2):
        for flag_four_or_nine in range(2):
            range_digit = 9 if flag_less else max_digit
            for d in range(range_digit+1):
                flag_less_next = 0
                flag_four_or_nine_next = 0   
                if flag_less==1 or d < max_digit:
                    flag_less_next = 1
                if flag_four_or_nine==1 or d ==4 or d==9:
                    flag_four_or_nine_next = 1
                dp[i+1][flag_less_next][flag_four_or_nine_next] += dp[i][flag_less][flag_four_or_nine]

print(dp[length][0][1]+dp[length][1][1])

問題

D – 禁止された数字

これまでのコードを利用して問題を解きます。

\( [A,B]=\{A,A+1,A+2,…,B\} \) を求めるには、単純に \( [0,A-1] \) を引けば求まります。

def return_nums_of_num_include_49(N):

    length = len(N)
    # dp[ 決めた桁数 ][ 未満フラグ ][4 or 9] := 総数
    dp = [[[0]*2 for _ in range(2)] for _ in range(length+1)]
    dp[0][0][0] = 1

    for i in range(length):
        max_digit = int(N[i])
        for flag_less in range(2):
            for flag_four_or_nine in range(2):
                range_digit = 9 if flag_less else max_digit
                for d in range(range_digit+1):
                    flag_less_next = 0
                    flag_four_or_nine_next = 0   
                    if flag_less==1 or d < max_digit:
                        flag_less_next = 1
                    if flag_four_or_nine==1 or d ==4 or d==9:
                        flag_four_or_nine_next = 1
                    dp[i+1][flag_less_next][flag_four_or_nine_next] += dp[i][flag_less][flag_four_or_nine]
    
    return dp[length][0][1]+dp[length][1][1]

a, b = input().split()
print(return_nums_of_num_include_49(b) - return_nums_of_num_include_49(str(int(a)-1)))