Codeforces Round #396 (Div. 2)

参加しました、2完でした。

A.o
B.o
C.x

A. Mahmoud and Longest Uncommon Subsequence

Problem - A - Codeforces

問題概要

文字列a, bが与えられる。どちらか一方"のみ"の部分列となる文字列の最大長さを求めよ。

制約

  • 1 <= len(a), len(b) <= 10^5

感想

A問題なのにいきなり結構難しいな、と10分近く悩んでしまった。
けれど、len(a) < len(b)のときはbはaの部分列に絶対ならないのでlen(b)が最長だと分かるし、len(a) > len(b)についても同様。
ではlen(a) == len(b)のときはどうかというと、aとbが完全に一致するときは明らかに-1
aとbが一致しないときは、a自身はbの部分列にならないのでlen(a)が最大

これをまとめると、a == bなら-1, a != bならmax(len(a), len(b))を出力すればよいことになる。計算量はO(N)。

def solve():
    a = input()
    b = input()

    if a == b:
        ans = -1
    else:
        ans = max(len(a), len(b))

    print(ans)

if __name__ == '__main__':
    solve()

B. Mahmoud and a Triangle

Problem - B - Codeforces

問題概要

長さa_iの線分がN本与えられる。
これらの線分の中から3本を選び、非退化な三角形を作れるかどうか判定せよ。
非退化な三角形とは、面積が正の値となる三角形のことである。

例えばa = 2, b = 5, c = 6を3辺に持つ三角形は非退化な三角形であるが、a = 2, b = 5, c = 7を3辺に持つ三角形は線分に潰れてしまっているので退化した三角形である。

制約

  • 3 <= N <= 10^5
  • 1 <= a_i <= 10^9

感想

線分を3本取ってきてそれらの長さをa, b, c(a <= b <= c)としたとき、三角不等式
a + b > c
が成り立つなら、この3本を用いて面積が正の値となる三角形を作れます。成り立たなければ作れません。

なのでこれを素朴に総当たりでやろうとすると、O(bin(N, 3)) = O(N^3)なのでまず総当たりでは無理(……と思ったのですが、ある方法を使えば素朴なやり方でも可能だったようです。これについては後述します)
なのでどうするかというと、まず{a_i}を昇順ソートします。

次にa_0 + a_1 > a_2が成り立っているかを見ます。もし成り立っていれば終わりです。
成り立ってなければソートされているのでa_0 + a_1 > a_j(j > 2)も明らかに成り立っていないことが分かりますのでここはチェックしなくてもよくなってます。
次にa_1 + a_2 > a_3が成り立っているかを見ます。
成り立ってなければ、a_i + a_j > a_k (i < j <= 2, k >= 3)が全て成り立たないことが分かります。
以下同様にやっていって、全て成り立たなければNOとなります。
計算量はソートにかかるO(Nlog(N))と線形に調べていくO(N)の和なので、O(Nlog(N))となります。

実際に書いたコード

def solve():
    n = int(input())
    A = [int(i) for i in input().split()]

    A.sort()

    for i in range(n - 2):
        if A[i] + A[i + 1] > A[i + 2]:
            print('YES')
            return
    else:
        print('NO')

if __name__ == '__main__':
    solve()

これが第一の解法でした。これだと何の面白味も無い感じですが、解説に書いてあるもう一方の解法が面白かったです。

どの3本を取っても三角形を作れないor退化した三角形を作れない数列{a_i}というのはどんなのだろう、っていうことを考えてみます。
これは{a_i}はソートされてるものとすると、上の考察から任意のiについてa_i + a_{i + 1} <= a_{i + 2}となっているときです。
このような数列になっていて、max(a_i)が最も小さくなるような数列{a_i}っていうのは上の不等式で常に=が成り立っていて、かつa_1 = a_2 = 1のとき、つまりこれはフィボナッチ数列です。
故にフィボナッチ数列のn番目の項をfib(n)とすると、どの3本を取っても三角形を作れないor退化した三角形を作れない数列{a_i}はmax(a_i) >= fib(N)が必ず成り立っています。
これは対偶を取ると、max(a_i) < fib(N)ならば数列{a_i}はどの3本を取っても三角形を作れないor退化した三角形を作れない数列「ではない」、つまりある3本を取ると非退化な三角形が作れる数列になっています。
ところで制約条件より、max(a_i) <= 10^9であることが分かっていますから、max(a_i) < fib(45) = 1,134,903,170が常に成り立ちます。
つまり、N >= 45のときは数列の中身を見ること無しに、必ず非退化な三角形が作れる数列であることが分かります。
後はN < 45の場合については、上に書いたO(N^3)の方法で素朴にやっても十分間に合うし、もちろんO(Nlog(N))の方法でやってもいいですね、と言う感じでした。
必要条件を使えばこういう上手い計算量の落とし方があるんだなあ、と思って感動しました。

というわけでN >= 45はYES、N < 45は素朴にやる方法のコード

from itertools import combinations

def solve():
    n = int(input())

    if n >= 45:
        print('YES')
        return

    A = [int(i) for i in input().split()]

    for i, j, k in combinations(range(n), 3):
        c = max(A[i], A[j], A[k])

        if A[i] + A[j] + A[k] - c > c:
            print('YES')
            break
    else:
        print('NO')

if __name__ == '__main__':
    solve()

C. Mahmoud and a Message

Problem - C - Codeforces

問題概要

各アルファベットに「そのアルファベットを含む文字列として許される最大の長さ」が与えられる。
例えば、a = 2なら文字列'a', 'aa'までは許されるが'aaa'、'aaaa'は許されない。
今文字列が与えられるので、この文字列を上の制限に引っかからないよう適当に分割して表現したい。
このとき、以下の問いに答えよ

  • 与えられた文字列の分割の仕方は何通りあるか?
  • 分割したときの部分文字列の中で最大の長さを持つ文字列の長さ
  • 分割回数の最小

制約

  • 1 <= 文字列の長さ <= 1000
  • 1 <= 各アルファベットに与えられる数 <= 1000

感想

何かDPっぽい感じはするけど、どうやって漸化式を立てるのかっていう構想がなかなか上手くできませんでした。
実際コーディングしてからもいろいろ間違えてたりっていうのがあってスムーズに行かなかったですね。
すぐには答えが浮かばない問題に取り組むときってとりあえず分かりやすい例を作って、そっからヒントを得てアルゴリズムを作ってコーディングっていう流れでだいたいやってるんですけど、こういうやり方だと作った例が特殊でその例でしか上手くいかないアルゴリズムとかを組んじゃったりして間違えるんですよね。

解説を読んで改めて見返してみると典型的なDP問題だなあ、という感じがします。こういうのがぱっと解けるようになりたい。

簡潔に説明する気力と能力が無いのでソースコードだけ貼っておきます。

まず解説を見ずに一応自力で解いたコード。
dequeを使って、分割したときの一番後ろの文字列がiとなる奴が何個あるか、っていうのを保管して何かやってます。
最小の分割回数を求めるのは前から貪欲法でやるのでそこだけ処理を分けてます。

import sys
from collections import deque

def get_minsp(message, As):
    min_sp = 1
    cur_len = 0
    capa = 0

    for ch in message:
        index = ord(ch) - ord('a')

        if cur_len == 0:
            cur_len += 1
            capa = As[index]
        else:
            capa = min(capa, As[index])

            if cur_len + 1 <= capa:
                cur_len += 1
            else:
                min_sp += 1
                cur_len = 1
                capa = As[index]

    return min_sp

def solve():
    MOD = 10**9 + 7
    n = int(input())
    message = input()
    As = [int(i) for i in input().split()]

    capas = []

    for ch in message:
        i = ord(ch) - ord('a')
        capas.append(As[i])

    deq = deque()
    cur_len = 0
    max_len = 0
    min_sp = 0
    capa = 0

    for i, capa in enumerate(capas):
        if not deq:
            deq.append(1)
            cur_len = 1
            max_len = 1
        else:
            limit = capa
            lim_len = capa

            for j in range(1, min(i + 1, capa)):
                if min(limit, capas[i - j]) < j + 1:
                    lim_len = j
                    break
                limit = min(limit, capas[i - j])
            else:
                lim_len = min(i + 1, capa)

            deq.appendleft(sum(deq) % MOD)

            if cur_len + 1 <= lim_len:
                cur_len += 1
                max_len = max(cur_len, max_len)
            else:
                num_del = cur_len + 1 - lim_len

                for j in range(num_del):
                    deq.pop()

                cur_len = lim_len

    min_sp = get_minsp(message, As)

    print(sum(deq) % MOD)
    print(max_len)
    print(min_sp)

if __name__ == '__main__':
    solve()

次に解説通りにdpテーブルを作ってやるやり方。
本質的には一緒だけど、微妙に違う感じがします。

(17/05/10/17:56 更新)
今見たらいろいろ分かりにくい感じがしたのでちょっと書き直しました。
今やってみると普通にこっちの方が分かりやすいですね。最初に自力で書いたdequeを使う方法は今見ると「なんだこれは……」って感じがします。何がしたいかは自分で書いたから何となく分かるんですけど、なんかよくこんな方法で出来たなあって思いました。

import sys

mod = 10**9 + 7
inf = 1<<30

def solve():
    n = int(input())
    msg = input()
    clim = [int(i) for i in input().split()]

    dp = [0]*(n + 1)
    dp[0] = 1
    max_len = 0

    for i in range(1, n + 1):
        lim_len = inf
        cur_len = 0

        for j in range(i, 0, -1):
            alpha = ord(msg[j - 1]) - ord('a')
            lim_len = min(lim_len, clim[alpha])
            cur_len += 1

            if cur_len > lim_len:
                break

            max_len = max(max_len, cur_len)
            dp[i] = (dp[i] + dp[j - 1]) % mod

    min_sp = 0
    cur_len = 0
    lim_len = inf

    for ch in msg:
        cur_len += 1
        alpha = ord(ch) - ord('a')
        lim_len = min(lim_len, clim[alpha])

        if cur_len > lim_len:
            cur_len = 1
            min_sp += 1
            lim_len = clim[alpha]

    print(dp[n])
    print(max_len)
    print(min_sp + 1)

if __name__ == '__main__':
    solve()