js2021 D問題 Nowhere P 数学問題を解説

js2021 D問題 Nowhere P を噛み砕いて解説します。

atcoder.jp

問題

2 以上の整数 P が与えられます。これはあなたの嫌いな数です。
整数の列  A_1, A_2, ..., A_N が以下の条件を満たすとき、この列を とても良い列と呼びます。
1 以上 N 以下のどの整数 i についても、 A_1 + A_2 + ... + A_i は P の倍数でない
各要素が 1 以上 P−1 以下であるような長さ N の整数列は全部で  {(P−1)}^N 個存在しますが、このうち とても良い列はいくつあるでしょうか?

ただし、答えは非常に大きくなることがあるので、答えを  (10^9 + 7) で割った余りを出力してください。

解き方のイメージ

まずNやPの制約が最大で 10^9になってる時点で全探索はないなと考えます。
何かしら工夫が必要なのですが、「何かしらの工夫」を見つけるために、程よく小さめなNやPのケースを手を動かして考えてみます。
サンプルにある N = 3, P = 3で考えてみます。
N = 3なので、 A_1 + A_2 + A_3が3の倍数でないものがとても良い列になります。
 A_1を選ぶとすると、P = 3なので、1 か2のどちらかになります。
次に、 A_2を選ぶ時に、 A_1 + A_2が3で割り切れないように選びます。
 A_1 = 1の時は A_2 = 1 A_1 = 2の時は A_2 = 2です。
次に A_3を選ぶときも A_1 + A_2 + A_3が3で割り切れないように選びます。
 A_1 + A_2 = 1 + 1 = 2の時は A_3 = 2 A_1 + A_2 + A_3 = 1 + 1 + 2 = 4です。
 A_1 + A_2 = 2 + 2 = 4の時は A_3 = 1 A_1 + A_2 + A_3 = 2 + 2 + 1 = 5です。
Nの数を増やして、N = 4, P = 3だとしても、次の A_4の選び方はそれぞれの分岐で1通りしかないことがわかります。
 A_1 + A_2 + A_3 = 1 + 1 + 2 = 4の時は A_4 = 1 A_1 + A_2 + A_3 + A_4 = 1 + 1 + 2 + 1 = 5です。
 A_1 + A_2 + A_3 = 2 + 2 + 1 = 5の時は A_4 = 2 A_1 + A_2 + A_3 + A_4 = 2 + 2 + 1 + 2 = 7です。


N = 3, P = 4 で考えます。
 A_1は特に制約はなく1〜3の間で好きに選べます。
 A_2 A_1 = 1の時、4の倍数にならないように A_2 = 1 or 2 の2通りの中から選びます。
もちろん A_1 = 2の時、4の倍数にならないように A_2 = 1 or 3の2通りから選びますし、
 A_1 = 3の時も4の倍数にならないように A_2 = 2 or 3の2通りから選びます。
次に A_3の選び方を考えます。
 A_1 + A_2 = 1 + 1 = 2の時、4の倍数にならないように A_3 = 1 or 3の2通りから選びます。
他の分岐も同じです。全部4の倍数にならないように3 - 1通りの2通りから選ぶことになります。
N = 3までの分岐の樹形図を描きました。

N=3、P=4の時の樹形図
N=3、P=4の時の樹形図

これまでのサンプルを踏まえてとても良い列の場合の数を考えると、
最初の選び方は (P - 1)通り、そのあとは (P - 2)通りを N - 1回繰り返します。

なので、答えは (P - 1) \times (P - 2)^{N - 1}です。

 10^9 + 7で割った余りが求められています。

 10^9 + 7の計算は@drkenさんの以下の記事を参考にしました。
qiita.com

コードは以下のようになりました。

N, P = map(int, input().split())
MOD = 10**9 + 7


# a^n modを計算する
def modpow(a: int, n: int, m: int):
    res = 1
    while n > 0:
        if n & 1:
            res = res * a % m
        a = a * a % m
        n = n >> 1
    return res


ans = ((P - 1) % MOD) * modpow((P - 2), (N - 1), MOD)
print(ans % MOD)