// Reads n, m and the modulus P from stdin and prints an answer modulo P.
// NOTE(review): the original problem statement is not part of this file; the
// comments below describe what the arithmetic does, not its combinatorial
// meaning.
//
// n == 1 and n == 2 have closed forms. n >= 3 runs an O(n * m) DP whose
// per-row transitions are range updates applied through difference arrays.
#include <cstdio>

using ll = long long;

namespace {

// Capacity of every per-column array. Columns are 1..m and range updates also
// write index m + 1, so the DP path requires m + 1 < kMaxCols.
// NOTE(review): the value suggests a constraint such as n * m <= 1e7, i.e.
// m <= 3333333 once n >= 3 — confirm against the problem statement.
constexpr int kMaxCols = 3333333 + 500;

int P;  // modulus, read from the input

// Two DP layers (current / previous row, selected by row % 2), indexed by
// column j in 1..m. The answer is the sum of f over j after the last row.
// NOTE(review): what the f and g states count is not recoverable from this
// file alone.
int f[2][kMaxCols];
int g[2][kMaxCols];
// Buffers for the "ramp" (arithmetic-progression) updates that are added into
// g: ramp1 holds first-order differences, ramp2 the cancelling corrections.
int ramp1[kMaxCols];
int ramp2[kMaxCols];

/// a = (a + b) mod P, for a in [0, P) and b in (-P, P).
/// The sum is formed in 64 bits so it cannot overflow even when P is close to
/// INT_MAX (the old int version overflowed once P exceeded about 2^30).
void add(int &a, int b) {
    ll s = static_cast<ll>(a) + b;
    if (s >= P) s -= P;
    if (s < 0) s += P;
    a = static_cast<int>(s);
}

/// Returns a * b % P. The result keeps the sign of a * b, so it lies in
/// (-P, P); every caller feeds it to add(), which accepts that range.
int mul(int a, int b) {
    return static_cast<int>(static_cast<ll>(a) * b % P);
}

/// Returns k * (k + 1) / 2 mod P, the number of segments [l, r] with
/// 1 <= l <= r <= k. (k + 1) is formed in 64 bits to avoid int overflow.
int C2(int k) {
    return static_cast<int>(1LL * k * (k + 1LL) / 2 % P);
}

/// Difference-array range add: after one prefix-sum pass over `tab`, every
/// index in [l, r] has gained `val`. Writes tab[l] and tab[r + 1].
void range_add(int *tab, int l, int r, int val) {
    add(tab[l], val);
    add(tab[r + 1], -val);
}

/// Ramp update over [l, r]. After the two-pass reconstruction in step(),
/// index l + k (0 <= k <= r - l) has gained (k + 1) * diff, and indices past
/// r are unchanged. `first` receives the first-order differences; `second`
/// receives the correction at r + 1 that cancels the accumulated ramp.
/// An empty range (l > r) is a no-op (the original code added and then
/// removed the same value, which nets to zero).
void add_ramp(int l, int r, int *first, int *second, int diff) {
    if (l > r) return;
    range_add(first, l, r, diff);
    add(second[r + 1], -mul(r - l + 1, diff));
}

/// Fills layer `cur` of f and g (columns 1..m) from layer `prev`.
/// Index 0 of every array is never written, so it stays 0 and can serve as
/// the prefix-sum base.
void step(int cur, int prev, int m) {
    int *fc = f[cur];
    int *gc = g[cur];
    const int *fp = f[prev];
    const int *gp = g[prev];

    for (int j = 1; j <= m; ++j) fc[j] = gc[j] = ramp1[j] = ramp2[j] = 0;

    // Contributions of the previous g layer.
    for (int j = 1; j <= m; ++j) {
        range_add(fc, j, j, mul(j, gp[j]));
        add_ramp(1, j - 1, ramp1, ramp2, gp[j]);
    }
    // Contributions of the previous f layer.
    for (int j = 1; j <= m; ++j) {
        const int jf = mul(j, fp[j]);
        range_add(fc, j, m, jf);
        add_ramp(1, j - 1, ramp1, ramp2, mul(m - j + 1, fp[j]));
        range_add(gc, j, m, mul(m - j, jf));
        add_ramp(j + 1, m, ramp1, ramp2, -jf);
    }

    // Pass 1: prefix sums turn the range updates in f/g into values, and the
    // ramp's first-order differences into per-index slopes.
    for (int j = 1; j <= m; ++j) {
        add(fc[j], fc[j - 1]);
        add(gc[j], gc[j - 1]);
        add(ramp1[j], ramp1[j - 1]);
    }
    // Pass 2: a second prefix sum plus the corrections gives the ramp values,
    // which are then added into g.
    for (int j = 1; j <= m; ++j) {
        add(ramp1[j], ramp1[j - 1]);
        add(ramp1[j], ramp2[j]);
        add(gc[j], ramp1[j]);
    }
}

}  // namespace

int main() {
    int n = 0;
    int m = 0;
    // Reject unreadable input and a non-positive modulus (P == 0 would divide
    // by zero in every % P below).
    if (std::scanf("%d%d%d", &n, &m, &P) != 3 || n < 0 || m < 0 || P <= 0)
        return 1;

    if (n == 1) {
        std::printf("%d\n", C2(m));
        return 0;
    }
    if (n == 2) {
        // All ordered pairs of segments, minus the ordered pairs of disjoint
        // segments. For each start i of the right segment there are
        // (m - i + 1) possible ends and C2(i - 1) segments lying entirely in
        // [1, i - 1]; the factor 2 counts both orders.
        ll all = 1LL * C2(m) * C2(m) % P;
        for (int i = 1; i <= m; ++i)
            all = (all - 2LL * (m - i + 1) % P * C2(i - 1) % P + P) % P;
        std::printf("%lld\n", all);
        return 0;
    }

    // The DP indexes columns up to m + 1 in fixed-size arrays.
    if (m + 1 >= kMaxCols) return 1;

    // Layer 0 seeds the DP: g = 1 at every column, f = 0.
    for (int j = 1; j <= m; ++j) g[0][j] = 1;
    for (int row = 1; row <= n; ++row) step(row % 2, (row + 1) % 2, m);

    int out = 0;
    for (int j = 1; j <= m; ++j) add(out, f[n % 2][j]);
    std::printf("%d\n", out);
    return 0;
}