bitDP(集合を用いたDP)について

ABC054のC - One-stroke PathをbitDPと呼ばれる手法で解いたので、bitDPについて書きます。

まずは問題を再掲します。

C: One-stroke Path - AtCoder Beginner Contest 054 | AtCoder

問題概要

N頂点の無向グラフが与えられる。グラフに二重辺と自己ループは無い。
頂点0からスタートして全ての頂点を1度だけ訪れるパスは何通りあるか?

制約

  • 2 <= N <= 8
  • 1 <= M <= N(N-1)/2

bitDPを用いた解法

bitDPはビット演算を用いた動的計画法の一種です。
そしてDPテーブルにどのようなキーを持たせるかということなんですが、この場合だとS = (既に訪れた頂点の集合)、v = (最後に訪れた頂点)の2つを使います。
このdp[S][v]に求めたいものを下から求めていくっていうのがbitDPですね。今回の問題の場合だと
dp[S][v] = (頂点0から出発し、頂点集合Sを全て訪れるpathのうち頂点vが最後になるようなpathの総数)
とします。そして、最終的に欲しいのは頂点0から出発して全ての点を訪れるpathの総数なので
sum(dp[{0,1,2,...,N-1}][u] for u in range(N))
とすることで求められます。

問題は漸化式、部分問題の結果をどう生かすのかということですね。
今回の場合だと
dp[S][v] = sum(dp[S - {v}][u]| u in S - {v} and Adj[u][v])
という漸化式によって求めることができます。
Sを全てたどってvを最後に訪れるpathの総数は、S-{v}を全てたどって、vとつながってる点uを最後に回るpathの総数の総和で求められる、ということですね。
考え方自体は割とシンプルです。

これだけだとぱっと見ビット演算は関係無さそうです。これをなんでbitDPというのかというと、集合に2進数を対応付けるからなんですね。
N個の集合X = {0,1,2,...,N-1}の部分集合っていうのは、各要素を「取る」or「取らない」で表せるのでXの任意の部分集合はNbitの2進数に対応付けられます。
例えばX = {0,1,2}なら対応は以下のようになります。

Xの部分集合 対応する2進数 10進表記
{} 000 0
{0} 001 1
{1} 010 2
{0,1} 011 3
{2} 100 4
{0,2} 101 5
{1,2} 110 6
{0,1,2} 111 7

このように確かに部分集合とNbitの2進数が一対一に対応付けられます。
こうやって集合を2進数で表記すれば、後は集合の演算とかも全部bit演算でやることができます。だからbitDPって呼ばれてるようです。

ただ肝は集合をbitで表すことでは無く、dpテーブルのキーに集合を用いるってとこなのでbitDPって呼び方は変で、集合に対するDPって呼んだ方が分かりやすいって意見もありました。確かにそんな気もします。

ともあれ、このbitDPを用いてABC054のC問題を解いたソースコードがこちら
計算量はO(2^N * N * N) = O(2^N * N^2)となり、pythonでもN = 15, 16ぐらいまでならいけます

    def do_dp(N, Adj):
        univ = 2**N - 1
        dp = [[0]*N for i in range(univ + 1)]
        dp[1][0] = 1
     
        for S in range(2, univ + 1):
            for v in range(N):
                S2 = S & (univ ^ (1 << v))
                for u in range(N):
                    if ((1 << u) & S2) and Adj[u][v]:
                        dp[S][v] += dp[S2][u]

        ans = sum(dp[univ][u] for u in range(1, N))
     
        return ans
     
    def solve():
        N, M = map(int, input().split())
        Adj = [[0]*N for i in range(N)]
     
        for i in range(M):
            a, b = map(int, input().split())
            a -= 1
            b -= 1
            Adj[a][b] = 1
            Adj[b][a] = 1
     
        ans = do_dp(N, Adj)
     
        print(ans)
     
    if __name__ == '__main__':
        solve()

ここではunivが全体集合{0,1,2,...,N-1}に対応しています。
初期値はdp[{0}]{0] = dp[1][0]だけ1にして、後は全て0にしておきます。
forの中で定義しているS2ってのはbit演算を使っていて分かりにくいですが、これがS - {v}に対応しています。
そして、(1 << u) & S2は「S - {v}のなかにuは存在しているか?」を判断しています。
bitDPと聞くと難しそうですが、bit演算を使って何をするのかが分かれば、実装は割とシンプルだと思いました。

せっかくなので他の問題にも応用してみようと思い、この前解いたyukicoderのNo.90 品物の並び替えをbitDPで解いてみることにしました。
★2の問題です。

No.90 品物の並び替え - yukicoder

問題文(コピペ)

ここに0番〜(N-1)番の品物がある。
また、
item1 item2 score
という形式で書かれた得点表がある。
品物を並べた時、item1がitem2よりも前であればscore点を獲得できるという意味である。

得点表が与えられるので、品物を適切に並び替えた時、獲得できる得点を最大化したい。そのときの得点を出力せよ。

注意:LL系の言語だと工夫しないといけないかもしれません。

制約

  • 2<=N<=9
  • 1<=M<=N*(N-1)

解法

想定解は並べ方を全列挙して、その並びの点数を計算して、最大のものを取るという方法です。
この方法だと全列挙にN!かかり、その各並びについて点数を計算する時間がN*(N-1)/2かかりますから、全体でO(N!*N^2)の計算量となります。
この問題は1度解いたことあるんですが、pythonだと時間の制約が厳しくTLE、pypyで何とか通すことができるというレベルでした。
pythonでも全列挙で通している人が結構いたので何かやり方がまずかったのかもしれません)

それでさっきの問題もN!の並び方を全列挙して調べるという方法がbitDPで計算量を削減することができたので、この問題もbitDPで出来そうだなあって思いました。
実際bitDPで解くことは可能で、先ほどのようにS = (既に訪れた頂点集合)、v = (最後に訪れた頂点)として
dp[S][v] = (Sに含まれる頂点を全て訪問し、頂点vを最後に訪れるようなpathのうち、点数が最大となるpathの点数)
とします。品物を並べる問題でしたが、品物を頂点だと思って、並びを各頂点をその順で訪問していくpathだと思えばグラフの問題に帰着できます。
例えば、3 2 1 0という並べ方は最初に頂点3からスタートし、順に3→2→1→0と各頂点を1度ずつ訪問していくpathだと思うことができます。

そして漸化式は、
dp[S][v] = max(dp[S - {v}][u]|u in S - {v}) + sum(scores[u][v]|u in S - {v})
という感じで求めることができます。Sからvを除いた集合を全てたどるpathのうち、最も大きくなるものを取ってきて、そのスコアに(u, v)のスコアを足せばいいという感じですね。

最終的に求めたいものは
max(dp[{0,1,2,...,N-1}][u]|u in {0,1,2,...,N-1})
で求めることができます。

というわけでソースコードです。
もともとO(N!*N^2)だったものが、O(2^N * N^2)とかなり計算量が落ちたのでpythonでも楽々通るようになりました。
bitDPの凄さを思い知らされると同時に、階乗の増大スピードの恐ろしさを知ることができました。

def bitDP(N, Mat):
    univ = 2**N - 1
    dp = [[0]*N for i in range(univ + 1)]

    for S in range(univ + 1):
        for v in range(N):
            S2 = S & (univ ^ (1 << v))
            for u in range(N):
                if S2 & (1 << u):
                    dp[S][v] = max(dp[S][v], dp[S2][u])

            for u in range(N):
                if S2 & (1 << u):
                    dp[S][v] += Mat[u][v]

    ans = max(dp[univ][u] for u in range(N))

    return ans

def solve():
    N, M = map(int, input().split())
    Mat = [[0]*N for i in range(N)]

    for i in range(M):
        i1, i2, score = map(int, input().split())
        Mat[i1][i2] = score

    ans = bitDP(N, Mat)

    print(ans)

if __name__ == '__main__':
    solve()

最後に、bitDPを理解するにあたってこちらのスライドを参考にさせていただきました。

www.slideshare.net