#include <bits/stdc++.h>
using namespace std;
// Shorthand for the 64-bit signed integer type used throughout the solution.
using lld = long long;
// Capacity of the fixed-size global arrays below (4096 entries).
constexpr lld N = lld{1} << 12;
// Prime modulus; the running answer `res` is kept reduced modulo M.
constexpr lld M = 1000000007;
// t[1..n]: the sequence currently being built by solve().
lld t[N];
// fac[k]: m * (m - 1) * ... * (m - k + 1) modulo M, filled by init().
lld fac[N];
// Input: sequence length n and the number m of values a position may take.
lld n;
lld m;
// pos[i]: DP flag computed by check() for position i; pos[0] is the base case.
std::bitset<N> pos;
// last[v]: most recent index holding value v while check() scans; 0 = unseen.
int last[N];
// Running answer modulo M, accumulated by solve().
lld res = 0;
// Computes x^y modulo M by binary (square-and-multiply) exponentiation.
// x may be any value, including negative or >= M. It is normalized into
// [0, M) first, so every intermediate product stays below M^2 and fits in lld
// without signed overflow.
// y is the exponent; for y <= 0 the loop does not run and the result is 1.
// NOTE(review): not called anywhere in the visible code.
lld pot(lld x, lld y) {
    x %= M;
    if (x < 0)
        x += M;  // C++ '%' keeps the dividend's sign; map into [0, M)
    lld r = 1;
    while (y > 0) {  // 'y > 0' (not 'y'): a negative y would shift to -1 forever
        if (y & 1)
            r = r * x % M;
        x = x * x % M;
        y >>= 1;
    }
    return r;
}
void init() {
pos[0] = 1;
fac[0] = 1;
for (lld i = 0; i < min(n, m); ++i){
fac[i + 1] = (fac[i] * (m - i)) % M;
}
}
bool check() {
for (int i = 1; i <= n; ++i) {
if (last[t[i]] != 0)
pos[i] = pos[last[t[i]]] | pos[last[t[i]] - 1];
else
pos[i] = 0;
last[t[i]] = i;
}
bool result = pos[n];
for (int i = 1; i <= n; ++i) {
pos[i] = 0;
last[i] = 0;
}
return result;
}
// Recursively enumerates every sequence t[1..n] in canonical form: position p
// takes a value in 1..min(mx + 1, m), where mx is the largest value placed so
// far, so new values first appear in increasing order.
// For each completed sequence accepted by check(), adds fac[mx] to res and
// keeps res reduced modulo M.
// p: next position to fill (1-based); mx: largest value in t[1..p-1].
void solve(lld p, lld mx) {
    if (p <= n) {
        const lld upper = min(mx + 1, m);  // m and mx are fixed for this call
        for (lld v = 1; v <= upper; ++v) {
            t[p] = v;
            solve(p + 1, v > mx ? v : mx);
        }
        return;
    }
    if (!check())
        return;
    res += fac[mx];
    if (res >= M)
        res -= M;
}
// Entry point: reads n and m, precomputes tables, enumerates all candidate
// sequences, and prints the accumulated answer modulo M.
int main() {
    if (scanf("%lld%lld", &n, &m) != 2) {
        // Missing or malformed input: fail explicitly instead of silently
        // running with n = m = 0 and printing a meaningless answer.
        return 1;
    }
    init();
    solve(1, 0);
    printf("%lld", res);
    return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 | #include <bits/stdc++.h> using namespace std; typedef long long lld; constexpr lld N = 1 << 12; constexpr lld M = 1000000007; lld t[N]; lld fac[N]; lld n, m; bitset<N> pos; int last[N]; lld res = 0; lld pot(lld x, lld y) { lld r = 1; while (y) { if (y & 1) (r *= x) %= M; (x *= x) %= M; y >>= 1; } return r; } void init() { pos[0] = 1; fac[0] = 1; for (lld i = 0; i < min(n, m); ++i){ fac[i + 1] = (fac[i] * (m - i)) % M; } } bool check() { for (int i = 1; i <= n; ++i) { if (last[t[i]] != 0) pos[i] = pos[last[t[i]]] | pos[last[t[i]] - 1]; else pos[i] = 0; last[t[i]] = i; } bool result = pos[n]; for (int i = 1; i <= n; ++i) { pos[i] = 0; last[i] = 0; } return result; } void solve(lld p, lld mx) { if (p > n) { if (check()) { res += fac[mx]; if (res >= M) res -= M; //for (int i = 1; i <= n; ++i) // printf("%lld ", t[i]); //printf(": %lld\n", fac[mx]); } return; } for (lld i = 1; i <= min(mx + 1, m); ++i) { t[p] = i; solve(p + 1, max(mx, i)); } } int main(){ scanf("%lld%lld", &n, &m); init(); solve(1, 0); printf("%lld", res); return 0; } |
English