順列の問題はitertoolsを使いたおす

今回は順列の問題を解きました。
問題をざっくりというと、N個の街があって、全部の町を巡回します。
巡回する順番は無数にありますが、それらの全てのパターンの距離の平均値を求める問題です。

atcoder.jp

方針としてはNは2〜8なので、街の順番の配列の全パターンを考えても8! = 40320通り。
そこから移動距離を求めるのは 8-1通り。
 8! * (8-1)=282240なので、計算量は大丈夫そう。
街の順番の配列を生成して、その配列の要素ごとに移動距離を求めて、最後に割り算ををします。

街の順番の配列の全パターンを作るのは順列を作ることになるので、迷わずitertools.permutationを使いました。
練習になってない気がしてしまいますが。。。

docs.python.org

itertoolsはかなり便利で、iterableな配列を引数に渡すと、permutations(順列)だけでなくcombinations(組み合わせ)、accumulate(累積和)の配列を返してくれます。

from itertools import permutations
lst = [1, 2, 3]
list(permutations(lst))
# [(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)]

permutationのおかげですんなりと問題が解けました。

from itertools import permutations
from math import sqrt
N = int(input())
towns = []
for _ in range(N):
    x, y = [int(i) for i in input().split()]
    towns.append((x, y))
towns_distance = [[0 for _ in range(N)] for _ in range(N)]
for i in range(N):
    for j in range(N):
        if i == j:
            continue
        towns_distance[i][j] = sqrt((towns[i][0] - towns[j][0])**2 + (towns[i][1] - towns[j][1])**2)
patterns = list(permutations([i for i in range(N)]))
total = 0
for m in patterns:
    for i in range(N-1):
        total += towns_distance[m[i]][m[i+1]]
print(round(total/len(patterns), 10))

街の間の距離を事前に配列にメモしましたが、あまり意味なかったかも。

abc193のC問題 整数系の問題を解く

AtcoderのABC193に参加して解けたC問題のふりかえりです。
整数 N が与えられて、1 以上 N 以下の整数のうち、2 以上の整数 a,bを用いてa^bと表せないものはいくつあるでしょうか?
読解力がない自分は短い問題の方が好きです。
atcoder.jp

まずNは最大で 10^{10}なので、全探索は無理。
aが例えば最小の2のとき、bが最小の2のとき、それぞれのaとbはいくつになるんだろうと考えつけば、
aとbの探索範囲が現実的になってきます。
bが2の時は、 a^2 <= Nなので、aの範囲は 2 <= a <= \sqrt{N}
aが2の時は、 2^b <= Nなので、bの範囲は 2 <= b <= \log_{2} N
aもbも整数なので、適宜値を切り捨てたり、切り上げたりします。

max_a = ceil(sqrt(N))
max_b = floor(log2(N))

あとは a^bの値が重複して数えないようにして、Nからa^bが成立した回数を引いてあげれば答えです。

from math import log2
from math import floor, ceil
from math import sqrt
N = int(input())
max_b = floor(log2(N))
max_a = ceil(sqrt(N))
ok_set = set()
for b in range(2, max_b + 1):
    for a in range(2, max_a + 1):
        if a ** b <= N:
            ok_set.add(a**b)
        else:
            break
count = len(ok_set)
print(N - count)

最初に対数をとって、あれこれと不等式を数学的に解けないか考えて、無駄な時間を過ごしてしまったことは反省。。。
公式の解説は頭が良すぎる簡潔な表現で、コードも簡潔で劣等感を感じました。

2進数の扱い方 おせんべいの問題

今回はビット全探索を使いつつ2進数で操作する問題を解きました。
2進数の扱いについて、いくつかポイントを書きます。

atcoder.jp

問題をざっくりいうと、せんべいを表と裏を焼かないといけなくて、
表面を焼いてる途中で地震が起きて、何枚か裏返ったんだけど、まだ裏返ってないのがあるから、
行か列でせんべいをひっくり返して、最大で何枚のせんべいを裏面にして出荷できるか。

問題文が長いのもあるのですが、この問題のルールを理解するのに10分以上かかりました。。。
私の読解力に問題があるとは思うけど、慣れればすんなり理解できるのでしょうか。。。

方針としては、問題文でヒントがあるように、行数は最大で10というところに着目して、
行をひっくり返すパターンはビット全探索します。
せいぜい、2^{10} = 1024通り

どの列をひっくり返すかは、列を見て、表が多ければひっくり返します。
問題で問われているのは、操作をした後の裏面を焼いている枚数なので、
列で探索して表面と裏面の数で、多い数を足しあげていきます。

ひっくり返す操作はXOR(排他的論理和)を使う

たまたまzobrist hashingを知識として知っていたので、すんなり気づけました。
XORは値が等しいときはFalse、値が異なるときはTrueになります。
なので1をXORしてあげれば、0 XOR 1 -> 1になりますし、1 XOR 1 -> 0になります。

formatで10進数から2進数へ変換

format関数で10進数から2進数へ変換しました。
変換するときに行の数が最大で10なので、10桁になるようにしています。

# bitは整数が入ります。ビット全探索で使用するbit。
format(bit, '010b')

結局一番苦労したのは列単位で探索するところでした。
例えば3行4列の問題で、ビットが0000000110のとき、1行目と2行目をひっくり返すつもりなのですが、
その時に1行目のk列目をどうやってひろってきて、XORしたら良いかかなり悩み、不具合などもあって1時間以上苦労しました。。。

結局は涙ぐましい努力で乗り切りました。

int(flip_row_pos[10 - R + k])

以下は全文です。

R, C = list(map(int, input().split()))
senbei_pos = []
ans = 0
for _ in range(R):
    pos = list(map(int, input().split()))
    senbei_pos.append(pos)

for bit in range(2**R):
    total = 0
    copied_pos = senbei_pos[:]
    # Rの上限が10なので10桁の2進数になるように0で埋める
    flip_row_pos = list(format(bit, '010b'))
    for j in range(C):
        column = [p[j] for p in copied_pos]
        one_count = sum([column[k] ^ int(flip_row_pos[10 - R + k])
                         for k in range(R)])
        zero_count = R - one_count
        total += max(zero_count, one_count)
    ans = max(ans, total)
print(ans)

各桁の和の求め方いろいろ

f:id:hrksb5029:20210225022431p:plain
2つ数字が与えられて、各桁の和が大きい方の和を出力する問題です。
自分のやり方があまりにいけてなくて、他の方の解答を見て、目から鱗が落ちました。

atcoder.jp

基本方針として100の位と10の位と1の位の数字をそれぞれ求めて合計を出すことを考えます。

例えば、235という数字があったとします。
私の場合100の位の数字は、100で割った商で求めます。
235 // 100 -> 2
ここからがいけてなくて。。。
その後に、元の数字235から200を引きます。235 - 200 = 35
35を10で割った商から10の位を求めて、また35から30を引いて。。。

途中で書きたくなくなるくらいお粗末なやり方だなと思いました。

もっと美しい方法1

各桁を求めて最後に足す方針は同じです。
100の位は100で割った商で求めます。(ここまでは同じ)
10の位は10で割った商から10で割ったあまりを計算して求めます。
235だったら、235 // 10 -> 23
そこから10で割ったあまりは 23 % 10 -> 2

(235 // 10) % 10
# 2

1の位は10で割ったあまりで求めます。

もっと美しい方法2

入力が文字列なので、各文字を整数に変換してそのまま足します。

sum(map(int, '235'))
# 10

最後のやり方は、シンプルで、「あぁそうだよなー」と思いました。

あまり自分のダメさ加減を卑下しても仕方がないので、
こういう細かいことを積み重ねつつ、1つ賢くなったと前向きに考えていこうと思います。

やっとビット全探索に慣れてきた気がする

今回の問題はざっくりいうと、N人の議員がいて、M個の知り合い関係があって、派閥に含まれる議員は全員知り合いという条件のとき、派閥の最大人数は何人になるかという問題。
atcoder.jp
まず着目するのはN人の最大値が12というところで、ビット全探索すれば2^{12} = 4096通りになりそうと予想。
そこから12人の2人関係について列挙すると、{}_{12} \mathrm{C} _2 = 12 * 11 / 2 = 66通り。掛け合わせると265914で、計算量が億とかうん千万に届かないので、全列挙の方針で問題なしと判断しました。

ビットをたてた議員の組み合わせについて全員知り合いかどうか関係を全部チェックします。
最初は関係をtupleのsetに保存して、O(N)のオーダーでチェックしていたのですが、ロジックが甘いところがあり、74のテストケースで誤りが2つになりました。。。たぶん保存しているtupleの順番と検索しようとしている順番が逆のときにダメなんだと思います。(2, 3)の関係のときに、(3, 2)で検索してもだめという感じ。

関係をN * Nの配列に保存して、関係があれば1を入れるようにしました。

ただ、PyPyで提出したらなんとMLE(メモリオーバー)になってしまいました。。。
たまにあるのですが、なぜなのかいまだにわかりません。
Pythonで提出して無事にクリアしました。

ビット全探索に慣れてきた気がします。そしてこの感覚が勘違いでないことを祈りたい。あとこの感覚を忘れないうちに身体に染み込ませたい。

ビット全探索についてはけんちょんさんのとっても丁寧な説明で理解が進みました。

drken1215.hatenablog.com

from itertools import combinations
N, M = list(map(int, input().split()))
friends = [[0 for _ in range(N)] for _ in range(N)]
max_count = 0
if M == 0:
    print(1)
    exit()
for _ in range(M):
    x, y = list(map(int, input().split()))
    friends[x-1][y-1] = 1
    friends[y-1][x-1] = 1
for bit in range(1 << N):
    fr_list = []
    for i in range(N):
        if bit & (1 << i):
            fr_list.append(i)
    hantei = True
    for com in list(combinations(fr_list, 2)):
        if friends[com[0]][com[1]] == 0 or friends[com[1]][com[0]] == 0:
            hantei = False
            break
    if hantei:
        max_count = max(max_count, len(fr_list))
print(max_count)

問題の解説を読んで劣等感をおぼえる

f:id:hrksb5029:20210223232004p:plain
今回は工夫して全列挙すると解ける問題に挑戦してみました。
問題をざっくりと言うと、スーパーに入ってお目当ての品物2つを買うお客さんがN人(最大で30人)いて、入口と出口の位置をどこにしたら移動量が最小になるかを解く問題。

atcoder.jp

入力例と答えから、なんとなく入口と出口はN人のうちの誰かの品物の位置に合わせた方が良いと思いつきました。
入口の候補がN個、出口の候補がN個、そして入口と出口を固定した後にN人の移動距離を計算するので、計算量はざっと
O(N^3)
Nは最大で30なので30*30*30=27000なら全列挙できる!!全列挙して解きました。
あっさりと正解できたのですが、解説を読んで劣等感を味わった一文は次の通りです。

N個の数a_1, a_2, ..., a_nがあった時、|x-a_1| + |x-a_2| + ... + |x-a_n|の最小値を求めるときに、最小値となるxの値はa_1, a_2, ..., a_nの中央値となる。

えっ、そうなの?解説にあった感覚的な説明では理解できず色々と調べてしまいました。
結局、RPubs - 中央値を実感するの数式を使った説明で納得しました。
回帰の誤差、2乗するか絶対値をとるか - Qiitaという記事も面白くて、最小二乗法しか最近は使ってなかった私は最小絶対値法という存在と誤差の奥深さにハマり、競プロとは異なる世界で1時間ほど脱線していました。

話を問題に戻すと、移動距離は、入口から1つ目の商品の距離 + 2つの品物の距離 + 2つ目の商品から出口までの距離のN人分の合計になるので、
入口から1つ目の商品の距離は、入口の位置をs、1つ目の商品の位置をa_1, ..., a_nとしたときに、|s-a_1 + |s-a_2| + ... + |s-a_n|
2つ目の商品から出口までの距離は、出口の位置をe、2つ目の商品の位置をb_1, ..., b_nとしたときに、|e-b_1| + |e-b_2| + ... + |e-b_n|
となって、
結局、a_1, ..., a_nb_1, ..., b_nの中央値が移動距離が最小になる入口と出口の位置になります。
中央値はソートして、配列の要素数を2で割った切り捨て値の位置の値になります。
言葉だと伝わらなさそうなので、

org_list.sort()
median = org_list[len(org_list) // 2]

sortの計算量がO(n*log(n))なので、O(n*log(n))で求められることになります。
O(N^3)しか思いつかなかった。。。自力でこの頭の良い解法に気づける日は来るのだろうか。。。

N = int(input())
ab_arr = []
minimum = 1 << 64 - 1
for _ in range(N):
    a, b = list(map(int, input().split()))
    ab_arr.append((a, b))
for i in range(N):
    for j in range(N):
        iri = ab_arr[i][0]
        degu = ab_arr[j][1]
        amount = 0
        for ab in ab_arr:
            amount += abs(iri - ab[0]) + abs(ab[0] - ab[1]) + abs(ab[1] - degu)
        minimum = min(amount, minimum)
print(minimum)

ビット全探索の問題にハマる

ビット全探索の勉強のためにABC128のC問題に挑戦しました。

atcoder.jp

まずは念のために計算量の計算。スイッチの数が10個なので、それぞれのスイッチが消えているか、ついているかを全部列挙しても2**10 = 1024通り。

電球の数も最大で10で、それぞれの電球が点灯しているかチェックするためのスイッチのパターンも最大で10通り。

ざっくりと1024 * 10 * 10 = 102400なので、まぁ全然大丈夫なオーダーだなと思って、サラサラと書いたのですがバグを埋め込んでしまい、1時間もバグを見つけるのにかかってしまいました。。。

原因はインプットのスイッチの番号は1番から始まるのに対して、スイッチの位置に対応したビットシフトをするときに1つ多めにビットシフトしていました。

インプットのスイッチの位置を-1して解消したのですが、デバッガで数字見てても思い込みがあったのかなんなのか、気づけませんでした。
なんだかなぁ。。。

N, M = list(map(int, input().split()))
s_arr = []
count = 0
for _ in range(M):
    kss = list(map(int, input().split()))
    s_arr.append(kss[1:])
p = list(map(int, input().split()))
for bit in range(1 << N):
    all_bright = True
    for i, s in enumerate(s_arr):
        c = 0
        for ss in s:
            if bit & (1 << (ss - 1)):
                c += 1
        if c % 2 != p[i]:
            all_bright = False
            break
    if all_bright:
        count += 1
print(count)