#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>

// Per-column state of one row of the recurrence.
// The source does not say what "g" and "b" stand for (presumably
// "good"/"bad" counts — TODO confirm). What the code does guarantee:
//   g_s, b_s : this column's own values for the row;
//   g        : running prefix sum of (g_s - b_s) over columns 0..this one;
//   b        : running prefix sum of b_s over columns 0..this one.
// Values are reduced modulo `modulo` as they are produced (row 0's g_s/b_s
// are small and stored as-is). They may be negative, because C++ `%` keeps
// the sign of the dividend; only the final answer is normalised.
struct seg {
    long long g_s;
    long long g;
    long long b_s;
    long long b;

    seg(long long g_s, long long g, long long b_s, long long b)
        : g_s(g_s), g(g), b_s(b_s), b(b) {}
    seg() : g_s(0), g(0), b_s(0), b(0) {}
};

// Builds row 0. Column 0 is (m, m, 0, 0); column i has g_s = m - i and
// b_s = i, with g and b extended as running prefix sums.
// Requires row.size() >= m >= 1.
static void build_first_row(std::vector<seg>& row, int m, long long modulo) {
    row[0] = seg(m, m, 0, 0);
    for (int i = 1; i < m; ++i) {
        const long long g_s = m - i;
        const long long b_s = i;
        const long long g = row[i - 1].g + g_s - b_s;
        const long long b = row[i - 1].b + b_s;
        row[i] = seg(g_s, g % modulo, b_s, b % modulo);
    }
}

// Computes row `next` from row `prev` in O(m). It keeps the sliding sums
// gsum/bsum and updates them incrementally per column, instead of
// recomputing each column's value from scratch.
// Requires prev.size() >= m, next.size() >= m and m >= 1.
static void advance_row(const std::vector<seg>& prev, std::vector<seg>& next,
                        int m, long long modulo) {
    // Starting value: sum over k of (m - k) * prev[k].g_s. Column 0 reads
    // prev[0].g, which equals prev[0].g_s in every row (it is built as
    // (x, x, 0, 0)).
    long long gsum = (m * prev[0].g) % modulo;
    for (int j = 1; j < m; ++j) {
        gsum += (m - j) * prev[j].g_s;
        gsum %= modulo;
    }
    next[0] = seg(gsum, gsum, 0, 0);

    long long bsum_gbase = 0;  // sum of prev[k].g_s for k < j (mod modulo)
    long long bsum_bbase = 0;  // sum of prev[k].b_s for k < j (mod modulo)
    long long bsum = 0;
    for (int j = 1; j < m; ++j) {
        gsum -= prev[j - 1].g * (m - j + 1);
        gsum += prev[j].g * (m - j);
        gsum -= prev[j].g_s * (m - j);
        gsum %= modulo;

        bsum -= bsum_gbase * (j - 1);
        // Keep the accumulators reduced. Left unreduced they reach about
        // m * modulo, and the product with j below could overflow long long
        // (undefined behaviour). Reducing is safe because they are only used
        // in sums/products that are themselves taken mod `modulo`.
        bsum_gbase = (bsum_gbase + prev[j - 1].g_s) % modulo;
        bsum_bbase = (bsum_bbase + prev[j - 1].b_s) % modulo;
        bsum += bsum_gbase * j;
        bsum -= bsum_bbase;
        bsum %= modulo;

        const long long g_s = gsum;
        const long long b_s = bsum;
        const long long g = (next[j - 1].g + g_s - b_s) % modulo;
        const long long b = (next[j - 1].b + b_s) % modulo;
        next[j] = seg(g_s, g, b_s, b);
    }
}

// Runs the recurrence for n rows of m columns.
// Returns (g + b) of the last column of the final row, normalised into
// [0, modulo) for positive modulo. By the prefix-sum identities above, this
// equals the sum of g_s over that row, mod modulo.
// Total work is O(n * m).
static long long solve(int n, int m, long long modulo) {
    std::vector<seg> prev(m);
    std::vector<seg> next(m);
    build_first_row(prev, m, modulo);
    for (int i = 1; i < n; ++i) {
        advance_row(prev, next, m, modulo);
        std::swap(prev, next);  // O(1): swaps buffers, no element copies
    }
    const long long last = (prev[m - 1].g + prev[m - 1].b) % modulo;
    return (last + modulo) % modulo;
}

// Reads "n m modulo" from stdin and prints the result on one line.
// Exits with status 1 on unreadable input, or when the input would make
// the recurrence undefined: fewer than one column (out-of-bounds access)
// or modulo == 0 (division by zero).
int main() {
    int n = 0;
    int m = 0;
    long long modulo = 0;
    if (!(std::cin >> n >> m >> modulo)) {
        return 1;
    }
    if (m < 1 || modulo == 0) {
        return 1;
    }
    std::printf("%lld\n", solve(n, m, modulo));
    return 0;
}