#include <algorithm>
#include <bitset>
#include <cstdio>

// Reads "n m" from stdin and prints, modulo M, the number of sequences
// t[1..n] over the alphabet {1..m} that pass check(): the sequence can be cut
// into consecutive segments of length >= 2 whose first and last elements are
// equal. The answer is printed without a trailing newline.
//
// Method: solve() enumerates only restricted-growth strings (a position either
// reuses an already-introduced symbol or introduces the next one, mx + 1), so
// each equality pattern is visited once. check() depends only on which
// positions hold equal values, so a representative that uses mx distinct
// symbols stands for m * (m - 1) * ... * (m - mx + 1) = fac[mx] real
// sequences. The enumeration is exponential in n; only small n are practical.

using lld = long long;

constexpr lld N = 1 << 12;     // capacity of the per-position arrays; n must be < N
constexpr lld M = 1000000007;  // output modulus

lld t[N];            // current candidate sequence, 1-indexed
lld fac[N];          // fac[k] = m * (m - 1) * ... * (m - k + 1) mod M
lld n, m;
std::bitset<N> pos;  // pos[i]: prefix t[1..i] splits into valid segments
int last[N];         // last[v]: most recent position holding value v, 0 if none
lld res = 0;         // running answer mod M

// Computes x^y mod M by binary exponentiation (currently unused).
lld pot(lld x, lld y) {
    x %= M;  // keep x * x below 2^63
    lld r = 1;
    while (y) {
        if (y & 1) (r *= x) %= M;
        (x *= x) %= M;
        y >>= 1;
    }
    return r;
}

// Sets the DP base case and fills the falling-factorial table
// fac[0..min(n, m)].
void init() {
    pos[0] = 1;  // the empty prefix is trivially splittable
    fac[0] = 1;
    for (lld i = 0; i < std::min(n, m); ++i) {
        // Reduce (m - i) first: m may be far larger than M, and an unreduced
        // product would overflow long long.
        fac[i + 1] = fac[i] * ((m - i) % M) % M;
    }
}

// Returns whether t[1..n] can be cut into consecutive segments of length >= 2
// whose first and last elements are equal.
//
// pos[i] looks only at the most recent earlier occurrence j of t[i]: pos[j]
// already ORs in pos[k - 1] for every older occurrence k of the same value,
// so pos[j] || pos[j - 1] covers every segment that can end at position i.
// Clears pos[1..n] and last[] afterwards for the next call; values in t never
// exceed n (see solve()), so clearing last[1..n] suffices.
bool check() {
    for (int i = 1; i <= n; ++i) {
        const int j = last[t[i]];
        pos[i] = (j != 0) && (pos[j] || pos[j - 1]);
        last[t[i]] = i;
    }
    const bool result = pos[n];
    for (int i = 1; i <= n; ++i) {
        pos[i] = 0;
        last[i] = 0;
    }
    return result;
}

// Fills t[p..n] with every restricted-growth continuation: position p may
// reuse any of the mx symbols introduced so far, or introduce symbol mx + 1
// while mx < m. Each complete string passing check() adds fac[mx] to res.
void solve(lld p, lld mx) {
    if (p > n) {
        if (check()) {
            res += fac[mx];
            if (res >= M) res -= M;
        }
        return;
    }
    const lld limit = std::min(mx + 1, m);
    for (lld i = 1; i <= limit; ++i) {
        t[p] = i;
        solve(p + 1, std::max(mx, i));
    }
}

int main() {
    if (std::scanf("%lld%lld", &n, &m) != 2) return 1;
    // t, pos, last and fac are indexed by positions 0..n, so n must fit in N.
    if (n < 0 || n >= N) return 1;
    init();
    solve(1, 0);
    std::printf("%lld", res);
    return 0;
}