部分和問題のメモ化再帰: pythonの不正解コードの考察

記事内に広告が含まれています。

問題文

$N$ 個の整数 $A_0, A_1, … , A_{N-1}$ の中からいくつか選んで $X$ を作ることができますか。

引用元: アルゴ式 Q4.6 部分和問題(再帰-2)

メモ化再帰を使う有名な問題です。本記事のコードの正当性のチェックは、アルゴ式というサイトで行いました。

部分和問題 (再帰-2) | アルゴ式(beta)

メモ化再帰の考え方

再帰関数を以下のように定義します。

rec(i,w) := $A_{i-1}$ までの要素からいくつか選んで w を作れる場合はTrue, そうでない場合はFalse

再帰関数で2種類の引数を管理しているため、メモ memo[i,w] も2種類の引数を用いて管理します。すなわち、二次元のテーブルで管理します。

  • ベースケースは何か?
    • 数列の要素を0個( i=0 )使って総和0( w=0 )を作れるので True
  • 再帰関数の計算式は? -> w>= A[i-1]の場合とw<A[i-1]の2通りの式がある
    • w>= A[i-1]の場合、Trueになるのは以下の2通り
      • rec(i-1, w-A[i-1]) = True : A[i-1]を選択する場合
      • rec(i-1, w) =True : A[i-1] を選択しない場合
    • w<A[i-1]の場合、Trueになるのは以下の1通りのみ
      • rec(i-1, w) =True : A[i-1]を選択しない場合
  • メモは True/False を管理する。($A_{i-1}$ までの要素からいくつか選んで w を作れる場合はTrue, そうでない場合は False )

メモ化再帰コードを実装する際の考え方やテンプレートを以下のサイトでも紹介しています。ご興味があれば是非ご覧ください。

失敗コード

以下のコードを提出しましたが、不正解(TLE)となりました。どこが問題かわかりますか?

import sys
sys.setrecursionlimit(10**6)

N, W = map(int, input().split())
A = list(map(int, input().split()))

memo = [[False]*(W+1) for _ in range(N+1)]
memo[0][0] = True

cnt = 0
def rec(i,w):
    
    if i==0 and w==0:
        return True
    elif i==0 and w!=0:
        return False
    
    #メモの参照:同じ計算を繰り返すことを防ぐ
    if memo[i][w] != False:
        return memo[i][w]
    
    else:
        if w >= A[i-1]:
            #A[i-1]を選択する場合 or A[i-1]を選択しない場合
            memo[i][w] = rec(i-1, w-A[i-1]) or rec(i-1,w)
        else:
            #A[i-1]を選択しない場合
            memo[i][w] = rec(i-1, w)
        
        return memo[i][w]

if rec(N,W):
    print('Yes')
else:
    print('No')

問題点の解説

上記のコードの問題点は、7行目のメモの初期値にあります。

memo = [[False]*(W+1) for _ in range(N+1)]

なぜなら、メモの取りうる値を False / True の2通りだけにした際、

memo[i][w]が再帰計算の結果Falseになったのか?それとも未更新のためFalseなのか?

が判定できません。

そのため、仮にmemo[i][w] は再帰計算が完了して False になったとしても、再度計算されてしまい、メモ化再帰の高速化を実現できなくなっています。

そこで、メモ memo に入れる変数として、以下の3種類を用意すると解決できます。

  1. -1 : 初期値. まだ更新されていないことを表す
  2. False: $X$ を作ることができない
  3. True: $X$ を作ることができる

正解のコード

これまでの考察を踏まえ、メモの初期化部分とメモの参照部分を変えるとうまくいきます。

import sys
sys.setrecursionlimit(10**6)

N, W = map(int, input().split())
A = list(map(int, input().split()))

memo = [[-1]*(W+1) for _ in range(N+1)] #初期値を-1に変更
memo[0][0] = True

cnt = 0
def rec(i,w):
    
    if i==0 and w==0:
        return True
    elif i==0 and w!=0:
        return False
    
    #メモの参照:同じ計算を繰り返すことを防ぐ
    if memo[i][w] != -1:   #Falseではなく-1に変更
        return memo[i][w]
    
    else:
        if w >= A[i-1]:
            #A[i-1]を選択する場合 or A[i-1]を選択しない場合
            memo[i][w] = rec(i-1, w-A[i-1]) or rec(i-1,w)
        else:
            #A[i-1]を選択しない場合
            memo[i][w] = rec(i-1, w)
        
        return memo[i][w]

if rec(N,W):
    print('Yes')
else:
    print('No')

このコードで全問正解にすることができました。

まとめ

メモ化再帰を使う問題の実装において、自分の失敗したコードと問題点について整理しました。

本記事の結論を一言でまとめると、メモが更新/未更新かを判別できるような初期値を用意するべきとなります。

本記事のコードはアルゴ式にて全問正解になることを確認しております。

部分和問題 (再帰-2) | アルゴ式(beta)

また、メモ化再帰に関する参考書籍として以下を挙げておきます。

問題解決力を鍛える!アルゴリズムとデータ構造
競技プログラミング経験が豊富な著者が、「アルゴリズムを自分の道具としたい」という読者に向けて執筆。入門書を標榜しながら、AtCoderの例題、C++のコードが充実。入門書であり実践書でもある、生涯役立

本記事が何かの参考になれば幸いです。

コメント