bitDP(集合を用いたDP)について

ABC054のC - One-stroke PathをbitDPと呼ばれる手法で解いたので、bitDPについて書きます。

まずは問題を再掲します。

C - One-stroke Path

問題概要

N頂点の無向グラフが与えられる。グラフに二重辺と自己ループは無い。
頂点0からスタートして全ての頂点を1度だけ訪れるパスは何通りあるか?

制約

  •  2 \leq N \leq 8
  •  1 \leq M \leq N(N-1)/2

bitDPを用いた解法

bitDPはbit演算を用いたDP(動的計画法)の一種です。
一般にDPは「どのような状態を持たせればよいか?」を決めることが肝だと考えられます。
例えば有名なナップサック問題であれば、dp(i, W)などとして「i番目のアイテムまで考えて、ナップサックの容量がWであるという状態」を考えればよい、ということが分かると遷移もさほど難しくなく分かるというものでした。

では、bitDPの場合はどのような状態を持たせればよいのでしょうか?

今回の場合では、S = (既に訪れた頂点の集合), v = (最後に訪れた頂点) としてそのときの解をdp(S, v)として記録します。

集合を状態として持たせる」というところがbitDPの肝ですね。今回の場合だと、

dp(S, v) = (頂点0から出発し、集合Sに含まれる頂点を全て訪れるpathのうち頂点vが最後になるようなpathの総数)

とします。そして、最終的に欲しいのは頂点0から出発して全ての点を訪れるpathの総数なので

 \sum_{u = 1}^{N-1} dp(\{0,1,2,\dots, N-1\}, u)

とすることで求められます。

状態の次は遷移(=漸化式)ですが、これも難しくは無くて
 dp(S, v) = \sum_{u \in adj(v)} dp(S - \{ v \}, u)
と自然に求められます。ここで S - \{ v \}は集合Sから要素vを抜いた集合のことで、adj(v)は頂点vと辺で結ばれている頂点からなる集合です。

ちょっといかつい数式のように見えるかもしれませんが、「集合Sに含まれている頂点を全て通り最後に頂点vを訪れるパスは、集合S-\{v\}に含まれる全ての点を通って最後にuを通った後uvをつなぐ辺を通ってvに行くパスである」ということですね。

これがbitDPの全容です。しかしこれを読んだだけではイマイチbitDPと呼ばれる所以が分かりませんよね。これをなんでbitDPというのかというと、集合に2進数を対応付けるからなんですね。

N個の集合 X = \{0,1,2,...,N-1\}の部分集合っていうのは、各要素を「取る(1)」or「取らない(0)」で表せるのでXの任意の部分集合はN-bitの2進数に対応付けられます。例えばX = {0,1,2}なら対応は以下のようになります。

Xの部分集合 対応する2進数 10進表記
{} 000 0
{0} 001 1
{1} 010 2
{0,1} 011 3
{2} 100 4
{0,2} 101 5
{1,2} 110 6
{0,1,2} 111 7

このように確かに部分集合とN-bitの2進数が一対一に対応付けられます。こうやって集合を2進数で表記すれば、後は集合の演算とかも全部bit演算でやることができます。だからbitDPって呼ばれているようです。

ただ肝は集合をbitで表すことでは無く、状態に集合を用いるということだと思っています。なのでbitDPと呼ぶより、集合に対するDPとか集合DPって言った方が分かりやすいのではないかと思ったりします。(もう界隈ではbitDPで定着してしまったし、私もそれに慣れてしまったのでbitDPと呼んでいますが……)

ともあれ、このbitDPを用いてABC054のC問題を解いたソースコードがこちら。計算量は O(2^N \cdot N^2)となり、pythonでもN = 15, 16ぐらいまでならいけます。

def bit_dp(N, Adj):
    dp = [[0]*N for i in range(1 << N)]
    # dp({0}, 0) = 1 と初期化する
    dp[1][0] = 1

    for S in range(1 << N):
        for v in range(N):
            # v が S に含まれていないときはパスする
            if (S & (1 << v)) == 0:
                continue

            # sub = S - {v}
            sub = S ^ (1 << v)

            for u in range(N):
                # sub に u が含まれており、かつ u と v が辺で結ばれている
                if (sub & (1 << u)) and (Adj[u][v]):
                    dp[S][v] += dp[sub][u]

    ans = sum(dp[(1 << N) - 1][u] for u in range(1, N))
    return ans


def main():
    N, M = map(int, input().split())
    Adj = [[0]*N for i in range(N)]

    for _ in range(M):
        a, b = map(int, input().split())
        Adj[a-1][b-1] = 1
        Adj[b-1][a-1] = 1

    ans = bit_dp(N, Adj)

    print(ans)


if __name__ == '__main__':
    main()

初期値はdp[{0}]{0] = dp[1][0]だけ1にして、後は全て0にしておきます。forの中で定義しているsubってのはbit演算を使っていて分かりにくいですが、これが S - \{v\}に対応しています。そして(sub & (1 << u))は「集合subの中に頂点uが含まれているか」ということを判定しています。bitDPと聞くと難しそうですが、bit演算を使って何をするのかが分かれば、実装は割とシンプルだと思いました。

以下はおまけです。

せっかくなので他の問題にも応用してみようと思い、この前解いたyukicoderのNo.90 品物の並び替えをbitDPで解いてみることにしました。
★2の問題です。

No.90 品物の並び替え - yukicoder

問題文(コピペ)

ここに0番〜(N-1)番の品物がある。
また、
item1 item2 score
という形式で書かれた得点表がある。
品物を並べた時、item1がitem2よりも前であればscore点を獲得できるという意味である。

得点表が与えられるので、品物を適切に並び替えた時、獲得できる得点を最大化したい。そのときの得点を出力せよ。

注意:LL系の言語だと工夫しないといけないかもしれません。

制約

  • 2<=N<=9
  • 1<=M<=N*(N-1)

解法

想定解は並べ方を全列挙して、その並びの点数を計算して、最大のものを取るという方法です。
この方法だと全列挙にN!かかり、その各並びについて点数を計算する時間がN*(N-1)/2かかりますから、全体で O(N! \cdot N^2)の計算量となります。
この問題は1度解いたことあるんですが、pythonだと時間の制約が厳しくTLE、pypyで何とか通すことができるというレベルでした。
pythonでも全列挙で通している人が結構いたので何かやり方がまずかったのかもしれません)

それでさっきの問題もN!の並び方を全列挙して調べるという方法がbitDPで計算量を削減することができたので、この問題もbitDPで出来そうだなあって思いました。
実際bitDPで解くことは可能で、先ほどのようにS = (既に訪れた頂点集合)、v = (最後に訪れた頂点)として
dp[S][v] = (Sに含まれる頂点を全て訪問し、頂点vを最後に訪れるようなpathのうち、点数が最大となるpathの点数)
とします。品物を並べる問題でしたが、品物を頂点だと思って、並びを各頂点をその順で訪問していくpathだと思えばグラフの問題に帰着できます。
例えば、3 2 1 0という並べ方は最初に頂点3からスタートし、順に3→2→1→0と各頂点を1度ずつ訪問していくpathだと思うことができます。

そして漸化式は、
dp[S][v] = max(dp[S - {v}][u]|u in S - {v}) + sum(scores[u][v]|u in S - {v})
という感じで求めることができます。Sからvを除いた集合を全てたどるpathのうち、最も大きくなるものを取ってきて、そのスコアに(u, v)のスコアを足せばいいという感じですね。

最終的に求めたいものは
max(dp[{0,1,2,...,N-1}][u]|u in {0,1,2,...,N-1})
で求めることができます。

というわけでソースコードです。
もともとO(N!*N^2)だったものが、O(2^N * N^2)とかなり計算量が落ちたのでpythonでも楽々通るようになりました。
bitDPの凄さを思い知らされると同時に、階乗の増大スピードの恐ろしさを知ることができました。

def bitDP(N, Mat):
    univ = 2**N - 1
    dp = [[0]*N for i in range(univ + 1)]

    for S in range(univ + 1):
        for v in range(N):
            S2 = S & (univ ^ (1 << v))
            for u in range(N):
                if S2 & (1 << u):
                    dp[S][v] = max(dp[S][v], dp[S2][u])

            for u in range(N):
                if S2 & (1 << u):
                    dp[S][v] += Mat[u][v]

    ans = max(dp[univ][u] for u in range(N))

    return ans

def solve():
    N, M = map(int, input().split())
    Mat = [[0]*N for i in range(N)]

    for i in range(M):
        i1, i2, score = map(int, input().split())
        Mat[i1][i2] = score

    ans = bitDP(N, Mat)

    print(ans)

if __name__ == '__main__':
    solve()

最後に、bitDPを理解するにあたってこちらのスライドを参考にさせていただきました。

www.slideshare.net