#include <iostream>
#include <vector>
using namespace std;
#define LL long long
// Computes base^exp modulo `mod` by binary (square-and-multiply) exponentiation.
// Precondition: mod > 0. A non-positive exp skips the loop, so the result is 1 % mod.
// The result always lies in [0, mod):
//  - negative bases are normalised first;
//  - the initial accumulator is reduced, so any power taken mod 1 is 0.
// Products are formed in 128 bits (GCC/Clang __int128, consistent with the
// file's use of __builtin_popcount), so no modulus that fits in a long long
// can overflow the intermediate multiplication.
long long modexp(long long base, long long exp, long long mod) {
    long long res = 1 % mod;
    base %= mod;
    if (base < 0) {
        base += mod;
    }
    while (exp > 0) {
        if (exp & 1) {
            res = static_cast<long long>(static_cast<__int128>(res) * base % mod);
        }
        base = static_cast<long long>(static_cast<__int128>(base) * base % mod);
        exp >>= 1;
    }
    return res;
}
// Returns the inverse of x modulo `mod`: the y in [0, mod) with x * y == 1 (mod mod).
// Uses the extended Euclidean algorithm. Unlike a Fermat-based inverse
// (x^(mod-2)), this is correct for composite moduli as long as gcd(x, mod) == 1.
// For prime moduli it returns the same value a Fermat inverse would.
// When no inverse exists (gcd(x, mod) != 1, e.g. x == 0 mod `mod`), it returns 0.
// Because every value is 0 modulo 1, mod == 1 also yields 0.
// Precondition: mod > 0. The Bezout coefficients stay bounded by mod, so the
// products below cannot overflow for any positive long long modulus.
long long modinv(long long x, long long mod) {
    long long r_prev = mod;
    long long r_cur = x % mod;
    if (r_cur < 0) {
        r_cur += mod;
    }
    long long s_prev = 0;  // invariant: r_prev == s_prev * x (mod mod)
    long long s_cur = 1;   // invariant: r_cur  == s_cur  * x (mod mod)
    while (r_cur != 0) {
        const long long q = r_prev / r_cur;
        const long long r_next = r_prev - q * r_cur;
        r_prev = r_cur;
        r_cur = r_next;
        const long long s_next = s_prev - q * s_cur;
        s_prev = s_cur;
        s_cur = s_next;
    }
    if (r_prev != 1) {
        return 0;  // gcd(x, mod) != 1: x has no inverse modulo mod
    }
    return s_prev < 0 ? s_prev + mod : s_prev;
}
void precomputeFactorials(vector<LL>& fact, vector<LL>& invfact, int maxN, LL p) {
fact[0] = 1;
for (int i = 1; i <= maxN; i++)
fact[i] = (fact[i-1] * i) % p;
invfact[maxN] = modinv(fact[maxN], p);
for (int i = maxN; i >= 1; i--)
invfact[i-1] = (invfact[i] * i) % p;
}
// Counts the cells of the diagram built as follows:
//  - row `row` (1-based, row <= b) holds min(a, c + 1 - row) cells;
//  - rows are scanned top-down, and counting stops at the first row whose
//    width is not positive.
int compute_nmax(int a, int b, int c) {
    int cells = 0;
    for (int row = 1; row <= b; ++row) {
        const int limit = c + 1 - row;
        const int width = (a < limit) ? a : limit;
        if (width < 1) {
            return cells;
        }
        cells += width;
    }
    return cells;
}
LL compute_f_lambda(int a, int b, int c, int nmax, LL p, const vector<LL>& fact) {
LL hookProd = 1;
for (int i = 1; i <= b; i++) {
int rowCells = min(a, c + 1 - i);
if (rowCells <= 0) break;
for (int j = 1; j <= rowCells; j++) {
int arm = rowCells - j;
int leg = 0;
for (int i2 = i + 1; i2 <= b && j <= min(a, c + 1 - i2); i2++)
leg++;
int hook = arm + leg + 1;
hookProd = (hookProd * hook) % p;
}
}
return (fact[nmax] * modinv(hookProd, p)) % p;
}
// Returns a count modulo p built from hook-length values f(a', b', c').
// f(a', b', c') is the value compute_f_lambda returns for the diagram whose
// row i (1-based, i <= b') holds min(a', c' + 1 - i) cells.
//   - Base term: f(a, b, c)^2 mod p. nmax must equal compute_nmax(a, b, c).
//   - If d = a + b - c <= 1, the base term is returned directly.
//   - Otherwise, an alternating sum over all 2^d bit masks is returned:
//       * each set bit i shrinks a (if i < a) or b (otherwise) by one, and
//         always shrinks c by one;
//       * f(a2, b2, c2)^2 is added for an even popcount and subtracted for
//         an odd one;
//       * triples with a2 <= 0, b2 <= 0 or c2 < max(a2, b2) are skipped.
// Presumably f^2 is used because f_lambda^2 permutations share the RSK shape
// lambda. NOTE(review): confirm against the problem statement.
// Assumes p > 0 and that the hook products are invertible mod p
// (e.g. p a prime larger than a + b - 1). TODO confirm.
LL count_max_permutations(int a, int b, int c, int nmax, LL p) {
// One factorial table serves every sub-diagram below: a2 <= a, b2 <= b and
// c2 <= c imply compute_nmax(a2, b2, c2) <= nmax. invfact is filled but not used here.
vector<LL> fact(nmax + 1), invfact(nmax + 1);
precomputeFactorials(fact, invfact, nmax, p);
LL total = compute_f_lambda(a, b, c, nmax, p, fact);
total = (total * total) % p;
int d = a + b - c;
if (d <= 1) return total;
// NOTE(review): past this point `total` is unused. The mask == 0 term recomputes
// f(a, b, c)^2, and it survives the filter below only when a > 0, b > 0 and c >= max(a, b).
// NOTE(review): the terms below come from diagrams with n2 <= nmax cells, so the
// alternating sum mixes squares of counts for different sizes. Confirm this
// inclusion-exclusion is intended.
LL res = 0;
// NOTE(review): `1 << d` overflows int once d >= 31. For d == 31 the bound is
// negative and the loop is skipped (res stays 0); for d >= 32 the shift is
// undefined behaviour. The loop is exponential in d in any case.
for (int mask = 0; mask < (1 << d); mask++) {
int bits = __builtin_popcount(mask);
int a2 = a, b2 = b, c2 = c;
// Decode the mask: the low `a` bit positions shrink a2, the rest shrink b2;
// every set bit also shrinks c2.
for (int i = 0; i < d; i++) {
if (mask & (1 << i)) {
if (i < a) a2--;
else b2--;
c2--;
}
}
if (a2 <= 0 || b2 <= 0 || c2 < max(a2, b2)) continue;
int n2 = compute_nmax(a2, b2, c2);
LL flambda2 = compute_f_lambda(a2, b2, c2, n2, p, fact);
flambda2 = (flambda2 * flambda2) % p;
// Inclusion-exclusion sign from the number of shrink steps; res stays in [0, p).
if (bits % 2 == 0)
res = (res + flambda2) % p;
else
res = (res - flambda2 + p) % p;
}
return res;
}
int main() {
int a, b, c;
LL p;
cin >> a >> b >> c >> p;
int n = compute_nmax(a, b, c);
LL count = count_max_permutations(a, b, c, n, p);
cout << n << " " << count << endl;;
return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 | #include <iostream> #include <vector> using namespace std; #define LL long long LL modexp(LL base, LL exp, LL mod) { LL res = 1; base %= mod; while (exp > 0) { if (exp & 1) res = (res * base) % mod; base = (base * base) % mod; exp >>= 1; } return res; } LL modinv(LL x, LL mod) { return modexp(x, mod - 2, mod); } void precomputeFactorials(vector<LL>& fact, vector<LL>& invfact, int maxN, LL p) { fact[0] = 1; for (int i = 1; i <= maxN; i++) fact[i] = (fact[i-1] * i) % p; invfact[maxN] = modinv(fact[maxN], p); for (int i = maxN; i >= 1; i--) invfact[i-1] = (invfact[i] * i) % p; } int compute_nmax(int a, int b, int c) { int n = 0; for (int i = 1; i <= b; i++) { int cells = min(a, c + 1 - i); if (cells <= 0) break; n += cells; } return n; } LL compute_f_lambda(int a, int b, int c, int nmax, LL p, const vector<LL>& fact) { LL hookProd = 1; for (int i = 1; i <= b; i++) { int rowCells = min(a, c + 1 - i); if (rowCells <= 0) break; for (int j = 1; j <= rowCells; j++) { int arm = rowCells - j; int leg = 0; for (int i2 = i + 1; i2 <= b && j <= min(a, c + 1 - i2); i2++) leg++; int hook = arm + leg + 1; hookProd = (hookProd * hook) % p; } } return (fact[nmax] * modinv(hookProd, p)) % p; } LL count_max_permutations(int a, int b, int c, int nmax, LL p) { vector<LL> fact(nmax + 1), invfact(nmax + 1); precomputeFactorials(fact, invfact, nmax, p); LL total = compute_f_lambda(a, b, c, nmax, p, fact); total = (total * total) % p; int d = a + b - c; if (d <= 1) return total; LL res = 0; for (int mask = 0; mask < (1 << d); mask++) { int bits = __builtin_popcount(mask); int a2 = a, b2 = b, c2 = c; for (int i = 0; i < d; i++) { if (mask & (1 << i)) { if (i < a) a2--; else b2--; c2--; } 
} if (a2 <= 0 || b2 <= 0 || c2 < max(a2, b2)) continue; int n2 = compute_nmax(a2, b2, c2); LL flambda2 = compute_f_lambda(a2, b2, c2, n2, p, fact); flambda2 = (flambda2 * flambda2) % p; if (bits % 2 == 0) res = (res + flambda2) % p; else res = (res - flambda2 + p) % p; } return res; } int main() { int a, b, c; LL p; cin >> a >> b >> c >> p; int n = compute_nmax(a, b, c); LL count = count_max_permutations(a, b, c, n, p); cout << n << " " << count << endl;; return 0; } |
English