読者です 読者をやめる 読者になる 読者になる

GCJ Qual.2017 - B

GCJのRound1が全て終わりました。私はR1B, R1Cに参加しましたが、どちらも1000位以内には入れず、残念ながらここで敗退となりました。
来年まで競プロをやっていたら、次はR2に進出してみたいですね。

それで今回はQualification RoundのB問題についての記事です。
なんで今更という感じですが、この問題はいろんな解き方があって面白かったです。

Problem B. Tidy Numbers

問題概要

  • ある整数を10進表記したとき、左から広義の昇順で数字が並んでいる整数をtidy numberと呼ぶ
  • 例えば、8, 123, 33348, 3777, 5555 などはtidy numberである
  • 例えば、10, 33321, 57993 などはtidy numberではない
  • 1からNまでの整数を順に書いたとき、最後に書いたtidy numberは何か?

感想

色んな解き方があって面白い問題だと思いました。

まずSmallはN <= 1,000と小さいので、クエリごとにN, N-1, N-2, ..., 2, 1と逆にたどっていってtidy numberを見つけたらループを抜けて出力、という解法で間に合います
1ケースごとにO(NlogN)かかるので、全体でO(TNlogN)ですね

ちょっと工夫して前処理として先に動的計画法で各Nに対応する最後に書いたtidy numberを格納する配列を作っておけば前処理でO(N_max log(N_max)), 各クエリにO(1)で答えられるのでO(T + N_maxlog(N_max))となります

import sys

def solve():
    T = int(sys.stdin.readline())

    lts = [0] * (1000 + 1)

    for i in range(1000 + 1):
        i_st = str(i)
        flag = True

        for j in range(len(i_st) - 1):
            if i_st[j] > i_st[j + 1]:
                flag = False
                break

        if flag:
            lts[i] = i
        else:
            lts[i] = lts[i - 1]

    for tc in range(T):
        N = int(sys.stdin.readline())

        print('Case #{}: {}'.format(tc + 1, lts[N]))

if __name__ == '__main__':
    solve()

それでLargeですが何通りか解法があるようです。

まず実際に私も使った解法で、greedyにやるという方法があります。
例を作ったりsmallの出力を見たりしていると、どっかから先を999...9とするといいということが分かります。

例えば、N = 459527462 とします。
すると、上から桁を見ていくと、459までは昇順になってますが、4595というところで昇順でなくなっています。
459...で始まるtidyは最小でも459999...9ですから、これはNより大きくなってしまいます。
なので、459から始まるN以下のtidyは存在しないことが分かります。
458...で始めると458999...9と後ろに全て9を並べたのはtidyになり、これはN以下の最大のtidyです。
だから一般的には昇順でなくなったところの1個前の数字を-1して、そこから後ろを全部9にするとよさそうだなと思います。

基本的にこれでよさそうですが、例えばN = 122221 みたいなケースがちょっと問題です。
さっきみたいにやると122219とする、ということになっちゃいますがこれはtidyではないですね。
なぜこうなったかというと、2222とここは増加していないからですね。
だからこの場合は狭義で増加しているところまで戻って、そこを-1して、その後ろを全部9にします。
この場合だと12のとこはちゃんと増えてるので、12の2を-1して後ろを9にすると119999になります。
実装としては左から見ていくときにこの「ちゃんと増えるところ」を更新していきながら、減少するところを見つけたら、「ちゃんと増えるところ」を-1して、その後ろは全部9にするという処理をすればいいということになります。

実際に書いたコードはこんな感じです。
これでやると'1000'が'0999'になったりしてこのままキャストすると変なことになる言語もあるかもしれません。
そういう場合は先頭が0かどうかで場合分けするといいと思います。実際私は先頭が0かどうかで場合分けしてましたが、そのままキャストしても大丈夫だと後で気づきました。

import sys

def solve():
    T = int(sys.stdin.readline())
    
    for tc in range(T):
        N = int(sys.stdin.readline())
        ans = get_last_tidy(N)

        print('Case #{}: {}'.format(tc + 1, ans))

def get_last_tidy(m):
    ml = [int(i) for i in str(m)]
    p = 0

    for j in range(len(ml) - 1):
        if ml[j] < ml[j + 1]:
            p = j + 1
        elif ml[j] > ml[j + 1]:
            for k in range(p + 1, len(ml)):
                ml[k] = 9

            ml[p] -= 1

            break

    res = ''.join([str(i) for i in ml])
    res = int(res)

    return res

if __name__ == '__main__':
    solve()

2個目の解き方は何と全列挙です。
一見無謀ですが、EditorialにあるようにL桁以下の全てのtidy numberは「9個のボールをL + 1個の箱に入れる方法」に1:1に対応付けられることが分かります。
例えば、134457ならば

1 3 4 4 5 7 #
1 2 1 0 1 2 2

というように、「前の桁からいくつ増えたか?」をボールの数で表現できる(最初は0から、余ったボールは一番後ろの箱に入れる)ことが分かります。
故にL桁以下の全てのtidy numbersの総数は重複組合せH(L + 1, 9)で求められることとなり、最大ケースでもL = 18なのでH(19, 9) = C(27, 9) = 4,686,825通りでこれを全列挙してもさほど時間はかからないということが分かります。
故に、前処理で18桁以下の全てのtidy numberを昇順で列挙しておいて、それを配列に持っておけば各クエリごとに二分探索すればN以下の最大のtidy numberを十分速く見つけることが可能であるということになります。

これを実装するとなると、まず18桁以下の全てのtidy numbersを列挙するプログラムを作っておいて

def dfs(cur_str, cur_dig, dig_left, tidys):
    if dig_left == 0:
        tidys.append(int(cur_str))
        return

    for i in range(cur_dig, 10):
        dfs(cur_str + str(i), i, dig_left - 1, tidys)

tidys = []
n = 18

dfs('', 0, n, tidys)

print(*tidys)

これを実行して適当なファイルに書きだしておき、

python make_tidynums.py > tidy_nums.txt

最後にLargeのinputファイルをダウンロードして次のプログラムを走らせるという感じですかね。
こういうのはあまり慣れていないので、もうちょっといい方法があるかもしれません。

import sys
import bisect

def solve():
    with open('tidy_nums.txt', 'r') as f:
        tidys = [int(i) for i in f.readline().split()]

    T = int(sys.stdin.readline())

    for tc in range(T):
        N = int(sys.stdin.readline())

        j = bisect.bisect(tidys, N)

        ans = tidys[j - 1]

        print('Case #{}: {}'.format(tc + 1, ans))

if __name__ == '__main__':
    solve()

3つ目の解き方は二分探索からのアプローチです。
ある数Nが与えられたとき、N以下で最大のtidy numberを見つけよというのが元の問題でした。
これを解くのは難しいですが「N以上の整数で最小のtidy numberを求めよ」という問題は簡単に解けます。
例えば、N = 12321 なら 12333 とすればいいと見ただけですぐ分かります。
一般に左から見て行って下がるところがあれば、そこから先は1個前の数字をズラーっと並べればいいということが分かります。
そしてこれが分かると「ある数が与えられたとき、それ以上の最小のtidy numberがNを超えるかどうか」という判定問題が解けます。
そしてこれはどこかで真偽が入れ替わる境界があるので、この境界は二分探索で速く見つけられます。
というわけでこの方法を使うとO((logN)^2)程度の計算量で解くことができる、という感じでした。

import sys

def solve():
    T = int(sys.stdin.readline())
    
    for tc in range(T):
        N = int(sys.stdin.readline())

        btm = 1
        top = N + 1

        while top - btm > 1:
            mid = (top + btm) // 2
            m = get_next_tidy(mid)

            if m <= N:
                btm = mid
            else:
                top = mid

        ans = get_next_tidy(btm)

        print('Case #{}: {}'.format(tc + 1, ans))

def get_next_tidy(mid):
    ms = [int(i) for i in str(mid)]

    for i in range(len(ms) - 1):
        if ms[i] > ms[i + 1]:
            for j in range(i + 1, len(ms)):
                ms[j] = ms[i]
            break

    res = sum(ms[i] * 10**(len(ms) - 1 - i) for i in range(len(ms)))

    return res

if __name__ == '__main__':
    solve()

と、ここまでがEditorialで紹介されていた3つの解法ですね。
これだけでもいろんな解き方があることに気づかせてくれて面白いなあと思ったんですが、さらに面白い解き方があるようです。
ここに引用させていただきます。

1, 11, 111, 1111, 11111, ... のように「1をk個並べた整数」のことをレピュニットというらしいです。
詳しくはwikipediaを参照してください。
レピュニット - Wikipedia

ここで、k桁のレピュニットをR_kと表すことにすると、高々L桁のtidy numberは

c_1 R_1 + c_2 R_2 + \dots + c_L R_L
と表現することができます。ここで、0 \leq c_i \leq 9\sum c_i \leq 9を満たします。
このように表現できることが分かると、元の問題は、与えられた正整数Nについて
 \sum_{i = 1}^{i = L} c_i R_i \leq N
を満たす範囲で\sum_{i = 1}^{i = L} c_i R_iを最大化せよという問題に換言できることが分かります。

そしてこれは桁が大きいレピュニットをNから取れるだけ取っていく、という貪欲法で最大化できることが分かるのでコードにすると以下のような形でO(logN)で解くことができるという感じらしいです。
k桁のレピュニットは一般的にR_k = (10^k - 1) / 9 と求められるので以下のように漸化式を使って構成しなくてもよいです。

import sys

def solve():
    T = int(sys.stdin.readline())

    repunit = [0] * 20
    repunit[1] = 1

    for i in range(2, 20):
        repunit[i] = 10**(i-1) + repunit[i - 1]

    for tc in range(T):
        N = int(sys.stdin.readline())
        ans = 0
        lim = 9

        for i in range(19, 0, -1):
            take = min(N // repunit[i], lim)
            ans += repunit[i] * take
            lim -= take
            N -= repunit[i] * take

        print('Case #{}: {}'.format(tc + 1, ans))

if __name__ == '__main__':
    solve()

と、こんな感じでしょうか。凄くエレガントな解法で上手いなあと感心しました。
このようにいろんな解法が考えられる問題は面白くてよいなあと思いました。
他にも桁DPをやるというtweetもちらほら見かけたのですが、桁DPがどういうものかよく分かってないのでそれについては分かりませんでした。