[Python] Educational DP Contest D – Knapsack 1

問題

D – Knapsack 1D – Knapsack 1

回答

最初にTLEで間に合わなかった回答

考え方としては合っているはず…。

import sys

# inputを高速化する。
input = sys.stdin.readline

# 入力
N, W = map(int, input().split())
lst_weight = [0 for _ in range(N)]
lst_value = [0 for _ in range(N)]
for i in range(N):
    lst_weight[i], lst_value[i] = map(int, input().split())

def chmax(a, b):
    if a >= b:
        return a
    else:
        return b

dp = [[0 for i in range(W+1)] for j in range(N+1)]

for i in range(N):
    for sum_w in range(W+1):
        # i番目の品物を選ぶ
        if sum_w - lst_weight[i] >= 0:
            dp[i+1][sum_w] = chmax(dp[i+1][sum_w], dp[i][sum_w - lst_weight[i]] + lst_value[i])
        # i番目の品物を選ばない
        dp[i+1][sum_w] = chmax(dp[i+1][sum_w], dp[i][sum_w])

print(dp[N][W])

関数を使うと速くなる

他のコードを参考にして速度を上げるよう努力しました。最終的な分かれ道となったのは、処理を関数内に含めるかどうかです。

下の2つのコードは、処理を関数に含めるかどうかだけの違いですが、関数に含めるとACになります。

関数を使わないとTLE

import sys

input = sys.stdin.readline

N, W = map(int, input().split())

dp = [0] * (W + 1)

for _ in range(N):
    w, v = map(int, input().split())
    for wk in range(W, w-1, -1):
        tv = dp[wk - w] + v
        if tv > dp[wk]:
            dp[wk] = tv

print(dp[-1])

関数を使うとAC

import sys

input = sys.stdin.readline

N, W = map(int, input().split())

dp = [0] * (W + 1)
def knapsack(n, w):
    for _ in range(N):
        w, v = map(int, input().split())
        for wk in range(W, w-1, -1):
            tv = dp[wk - w] + v
            if tv > dp[wk]:
                dp[wk] = tv
    return dp[-1]

print(knapsack(N, W))

関数にすることで、変数がローカル変数として扱われるので処理速度が速くなります。

The Python Wiki PythonSpeed PerformanceTips Local Variables

なぜPythonのコードは関数の中では速くなるか? Why does Python code run faster in a function?

numpy を使った回答

numpy を使うとdpの行列を作って値を更新してくだけなので、 自分にとっては一番わかりやすい。

import sys
import numpy as np

# input処理を高速化する
input = sys.stdin.readline

# 入力
N, W = map(int, input().split())

def knapsack(n, w):
    # dp[i][j] : i番目の品物で重さj以下の価値の最大値
    dp = np.zeros([N+1, W+1], dtype='int64')
    # forは関数の中で回す。
    for i in range(N):
        w, v = map(int, input().split())
        # i+1番目でwより小さい場合は、i番目と変更なし
        dp[i+1][:w] = dp[i][:w]
        dp[i+1][w:] = np.maximum(dp[i][w:], dp[i][:-w] + v)
        # print(dp)
    return dp[-1][-1]

print(knapsack(N, W))