#include <bits/stdc++.h>
// Template shorthands. Of these, only the `ll` alias is used by this program.
#define pb push_back
#define fi first
#define se second
// Element count of a container as int (note: the argument is not parenthesized).
#define sz(x) (int)x.size()
// Debug helper: prints the expression text and its value to stderr.
#define cat(x) cerr << #x << " = " << x << endl
// Unties cin and disables stdio sync; expands to two statements.
// Not invoked here: the program does its I/O with scanf/printf.
#define IOS cin.tie(0); ios_base::sync_with_stdio(0)
// 64-bit integer used for overflow-free intermediate products.
using ll = long long;
using namespace std;
// Capacity of every global DP / scratch array.
// NOTE(review): the + 500 margin presumably covers the index m + 1 written by
// the range-update helpers — confirm against the constraint on m.
constexpr int N = 3333333 + 500;
// Modulus for all arithmetic helpers; read from stdin in main.
int P;
// In-place modular addition: a = (a + b), renormalized into [0, P).
// A single correction step is enough only when a is in [0, P) and b is in
// (-P, P); callers subtract by passing a negated value. a + b must fit in int.
void add(int &a, int b) {
    a += b;
    if (a >= P) {
        a -= P;
    } else if (a < 0) {
        a += P;
    }
}
int mul(int a, int b) {
return ll(a) * b % P;
}
int C2(int n) {
return 1LL * n * (n + 1) / 2 % P;
}
// n: number of DP layers; m: positions per layer (1-indexed). Both read in main.
int n;
int m;
// Rolling DP storage: layer i lives at index i % 2.
int f[2][N];
int g[2][N];
// Per-layer scratch buffers for the difference / ramp updates.
// NOTE(review): in the visible code df and qf never receive a non-zero value.
int df[N];
int dg[N];
int qf[N];
int qg[N];
// Difference-array range update: once `tab` is prefix-summed, every position
// in [l, r] has gained val (mod P). Writes tab[l] and tab[r + 1], so the array
// needs a slot at index r + 1. An empty range (r == l - 1) has no net effect.
void add(int *tab, int l, int r, int val) {
    int &open = tab[l];
    int &close = tab[r + 1];
    add(open, val);
    add(close, -val);
}
// Schedules the linear ramp diff * (k - l + 1) on positions k in [l, r] (mod P).
// It materializes after two prefix-sum passes:
//  - the first pass, over `tab`, yields a constant diff on [l, r];
//  - the second pass, over `tab` plus the corrections in `tab2`, accumulates
//    that constant into the ramp.
// The tab2[r + 1] entry subtracts the accumulated height (r - l + 1) * diff, so
// positions past r gain nothing. An empty range (r == l - 1) has no net effect.
void increasing(int l, int r, int *tab, int *tab2, int diff) {
    int &rampStart = tab[l];
    int &rampStop = tab[r + 1];
    add(rampStart, diff);
    add(rampStop, -diff);
    const int height = mul(r - l + 1, diff);
    add(tab2[r + 1], -height);
}
int main() {
scanf("%d%d%d", &n, &m, &P);
if (n == 1)
return !printf("%d\n", C2(m));
if (n == 2) {
ll all = 1LL * C2(m) * C2(m) % P;
for (int i = 1; i <= m; ++i)
all = (all - 2LL * (m - i + 1) % P * C2(i - 1) % P + P) % P;
return !printf("%lld\n", all);
}
for (int i = 1; i <= m; ++i)
g[0][i] = 1;
for (int i = 1; i <= n; ++i) {
int a = i % 2, b = (i + 1) % 2;
for (int j = 1; j <= m; ++j)
f[a][j] = g[a][j] = df[j] = dg[j] = qf[j] = qg[j] = 0;
for (int j = 1; j <= m; ++j) {
add(f[a], j, j, mul(j, g[b][j]));
increasing(1, j - 1, dg, qg, g[b][j]);
}
for (int j = 1; j <= m; ++j) {
add(f[a], j, m, mul(j, f[b][j]));
increasing(1, j - 1, dg, qg, mul(m - j + 1, f[b][j]));
add(g[a], j, m, mul(m - j, mul(j, f[b][j])));
increasing(j + 1, m, dg, qg, -mul(j, f[b][j]));
}
for (int j = 1; j <= m; ++j) {
add(f[a][j], f[a][j - 1]);
add(g[a][j], g[a][j - 1]);
add(dg[j], dg[j - 1]);
add(df[j], df[j - 1]);
}
for (int j = 1; j <= m; ++j) {
add(df[j], df[j - 1]);
add(df[j], qf[j]);
add(f[a][j], df[j]);
}
for (int j = 1; j <= m; ++j) {
add(dg[j], dg[j - 1]);
add(dg[j], qg[j]);
add(g[a][j], dg[j]);
}
}
int out = 0;
for (int i = 1; i <= m; ++i)
add(out, f[n % 2][i]);
printf("%d\n", out);
return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 | #include <bits/stdc++.h> #define pb push_back #define fi first #define se second #define sz(x) (int)x.size() #define cat(x) cerr << #x << " = " << x << endl #define IOS cin.tie(0); ios_base::sync_with_stdio(0) using ll = long long; using namespace std; const int N = 3333333 + 500; int P; void add(int &a, int b) { a += b; if (a >= P) a -= P; if (a < 0) a += P; } int mul(int a, int b) { return ll(a) * b % P; } int C2(int n) { return 1LL * n * (n + 1) / 2 % P; } int n, m, f[2][N], g[2][N], df[N], dg[N], qf[N], qg[N]; void add(int *tab, int l, int r, int val) { add(tab[l], val); add(tab[r + 1], -val); } void increasing(int l, int r, int *tab, int *tab2, int diff) { add(tab, l, r, diff); add(tab2[r + 1], -mul(r - l + 1, diff)); } int main() { scanf("%d%d%d", &n, &m, &P); if (n == 1) return !printf("%d\n", C2(m)); if (n == 2) { ll all = 1LL * C2(m) * C2(m) % P; for (int i = 1; i <= m; ++i) all = (all - 2LL * (m - i + 1) % P * C2(i - 1) % P + P) % P; return !printf("%lld\n", all); } for (int i = 1; i <= m; ++i) g[0][i] = 1; for (int i = 1; i <= n; ++i) { int a = i % 2, b = (i + 1) % 2; for (int j = 1; j <= m; ++j) f[a][j] = g[a][j] = df[j] = dg[j] = qf[j] = qg[j] = 0; for (int j = 1; j <= m; ++j) { add(f[a], j, j, mul(j, g[b][j])); increasing(1, j - 1, dg, qg, g[b][j]); } for (int j = 1; j <= m; ++j) { add(f[a], j, m, mul(j, f[b][j])); increasing(1, j - 1, dg, qg, mul(m - j + 1, f[b][j])); add(g[a], j, m, mul(m - j, mul(j, f[b][j]))); increasing(j + 1, m, dg, qg, -mul(j, f[b][j])); } for (int j = 1; j <= m; ++j) { add(f[a][j], f[a][j - 1]); add(g[a][j], g[a][j - 1]); add(dg[j], dg[j - 1]); add(df[j], df[j - 1]); } for (int j = 1; j <= m; ++j) { add(df[j], df[j - 1]); 
add(df[j], qf[j]); add(f[a][j], df[j]); } for (int j = 1; j <= m; ++j) { add(dg[j], dg[j - 1]); add(dg[j], qg[j]); add(g[a][j], dg[j]); } } int out = 0; for (int i = 1; i <= m; ++i) add(out, f[n % 2][i]); printf("%d\n", out); return 0; } |
English