1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#include <cstdio>
#include <vector>

/*
 * Returns sum_{k=1..m} dp[n][k] (mod p) for the recurrence
 *
 *     dp[1][k] = k
 *     dp[i][k] = k * (S[m] - S[m-k]) - sum_{j=1..k-1} S[j]      (i >= 2)
 *
 * where S[j] = dp[i-1][1] + ... + dp[i-1][j] is the prefix sum of the
 * previous row (S[0] = 0). Only two rows are kept: O(m) memory, O(n*m) time.
 *
 * Preconditions: n >= 1, m >= 1, p >= 1.
 * Every stored value is a residue in [0, p); the only multiplication is
 * (k mod p) * residue, so p must satisfy (p-1)^2 <= LLONG_MAX (p <= ~3.03e9).
 *
 * NOTE(review): the combinatorial meaning of dp comes from the original task
 * statement, which is not visible in this file.
 */
static long long dp_row_sum_mod(int n, int m, long long p) {
    // Heap storage instead of VLAs: VLAs are a non-standard C++ extension and
    // an input-sized stack array overflows the stack for large m.
    std::vector<long long> prev(m + 1, 0);   // dp[i-1][*]
    std::vector<long long> cur(m + 1, 0);    // dp[i][*]
    std::vector<long long> prefix(m + 1, 0); // prefix[j] = S[j], prefix[0] = 0

    for (int k = 1; k <= m; ++k)
        prev[k] = k % p;

    for (int i = 2; i <= n; ++i) {
        for (int k = 1; k <= m; ++k)
            prefix[k] = (prefix[k - 1] + prev[k]) % p;

        // Running value of sum_{j=1..k-1} S[j] (mod p) for the current k.
        long long below = 0;
        for (int k = 1; k <= m; ++k) {
            // S[m] - S[m-k]: sum of the top k entries of the previous row.
            long long top = (prefix[m] - prefix[m - k] + p) % p;
            // Reduce both factors before multiplying to keep the product < p^2.
            cur[k] = ((k % p) * top % p - below + p) % p;
            below = (below + prefix[k]) % p;
        }
        prev.swap(cur); // row i becomes the "previous" row for i+1
    }

    long long total = 0;
    for (int k = 1; k <= m; ++k)
        total = (total + prev[k]) % p;
    return total;
}

/*
 * Reads "N M P" from stdin and prints sum_{k=1..M} dp[N][k] mod P.
 * Exits with status 1 (message on stderr) on malformed or out-of-range
 * input instead of running into undefined behavior.
 */
int main() {
    int N, M;
    long long P;
    // N < 1 would read an uninitialized row, M < 1 gives an empty table,
    // and P < 1 would divide by zero in every '%'.
    if (scanf("%d %d %lld", &N, &M, &P) != 3 || N < 1 || M < 1 || P < 1) {
        fprintf(stderr, "invalid input\n");
        return 1;
    }

    printf("%lld\n", dp_row_sum_mod(N, M, P));
    return 0;
}