ZONe2021 C問題 MAD TEAM の解説とふりかえり

残念ながら解けなかったC問題 MAD TEAMを解説してふりかえります。

atcoder.jp

問題の概要

N 人のメンバー候補がいて、それぞれの人は、パワー・スピード・テクニック・知識・発想力の 5 種類の能力値を持っています。
N 人のメンバー候補から総合力が最大になるように3名を選ぶ問題です。
総合力は、それぞれの能力についてメンバーの中の最大値をとって、その5つの能力値の中で最小の値です。

難しいポイント

Nが最大で3000なので、単純に全ての組み合わせでチェックすると {}_{3000}C_{3} = 4495501000なので、タイムアウトします。
私は当てずっぽうで、それぞれの能力で1番目、2番目、3番目の能力値を持っているメンバーを選んで、その中で全通り調べましたが、それは誤った答えになりました。
他にも、それぞれの能力値が最大の値(答え)となり得るかチェックする方法を考えましたが、それでも組み合わせを調べないといけなくなり、ループの中で最大3000の中で2つを選ぶ組み合わせを調べるのも、 {}_{3000}C_{2} = 4498500となり、タイムアウトになります。
組み合わせを使わないように工夫が必要です。

解き方の方針

公式解説を読んで目から鱗が落ちました。

最小値の最大値のような形の問題では、答えを二分探索し、「答えが  x 以上になるか?」という判定問題に持ち込むと簡単になることがあります。

そういう発想があったとは。。。

値の範囲が 1 \leq A_i, B_i, C_i, D_i, E_i \leq 10^9なので、答えとなる最大値を xとして、 xの値を1から 10^9 + 1の範囲で二分探索します。
この xを置くことで、 xより大きいか、小さいかという基準だけが必要な情報になります。
全ての能力値をx以上なら 1 あるいは未満なら 0で圧縮することができます。
そうすると1人のメンバーの能力値のパターンとしては、 2^5 = 32種類だけになります。
 x以上の能力値を持っているメンバーで全通りの組み合わせを見て、全ての能力で x以上の場合、二分探索の判定をTrueにします。
総合力として、能力値の中の最小値をとってるので、どこかの能力値が xを下回っているとFalse、全ての能力値が xを上回っているから、能力値は更新されることになります。

判定がTrueということは x以上の値ばかりなので、探索範囲の下限を中央値にして次のループに入ります。

実装方法

公式解説の実装を説明します。
atcoder.jp

二分探索の実装については@hamkoさんの記事がおすすめです。
qiita.com

実装のポイントは二分探索の判定関数です。

まずはメンバーの能力を探索している x以上かどうかの2進数で表現します。
 x = 18で、あるメンバーの能力が (A, B, C, D, E) = (6, 19, 20, 5, 1)のとき、
18より大きいところに1を立てます。
Aの値に対して 2^0、Bの値に対して 2^1、Cの値に対して 2^2、Dの値に対して 2^3、Eの値に対して 2^4を割り当てたとすると、 00110_{(2)} = 6になります。

2進数でメンバーの能力を表現するときのイメージ図
2進数でメンバーの能力を表現するときのイメージ図
for a in A:
    s.add(sum(1 << i for i in range(5) if a[i] >= x))

総合力として、能力値の中の最小値をとってるので、どこかの能力値が xを下回っているとFalse、全ての能力値が xを上回っているから、能力値は更新されることになります。
全ての能力値が xを上回っているということは、各メンバーの能力を OR でとって、全てのビットに1が立っていればTrueになります。
全てのビットに1が立っているということは、 2^5 - 1 = 31に等しいかどうか判定します。

for member1 in s:
    for member2 in s:
        for member3 in s:
            if member1 | member2 | member3 == 31:
                return True
return False

全体のコードは以下のようになりました。

N = int(input())
A = [list(map(int, input().split())) for i in range(N)]


def check(x):
    s = set()
    for a in A:
        s.add(sum(1 << i for i in range(5) if a[i] >= x))
    for member1 in s:
        for member2 in s:
            for member3 in s:
                if member1 | member2 | member3 == 31:
                    return True
    return False


ok = 0
ng = 10**9 + 1
while ng - ok > 1:
    mid = (ok + ng) // 2
    if check(mid):
        ok = mid
    else:
        ng = mid
print(ok)

圧縮の工夫はとてもためになりました。
そして今の私の実力では何時間悩んでも思いつかなかったと思います。
判定をTrueにする条件も感動しました。
学びの多い問題だったので復習します。