// Brute-force counter: sums, modulo 1000000007, the length-n colorings over m colors accepted by check().
#include <bits/stdc++.h>
using namespace std;
using lld = long long;  // shorthand for every counter and modular value below

constexpr lld N = lld{1} << 12;  // capacity (4096) of every global array below
constexpr lld M = 1000000007;    // modulus applied to the answer

lld t[N];       // t[1..n]: label sequence currently being built by solve()
lld fac[N];     // fac[k] = m*(m-1)*...*(m-k+1) mod M, filled by init()
lld n;          // sequence length, read in main()
lld m;          // number of available colors, read in main()
bitset<N> pos;  // pos[i]: check()'s DP flag for prefix t[1..i]; pos[0] is the base case
int last[N];    // last[c]: latest index holding label c while check() scans
lld res = 0;    // running answer, kept in [0, M)

// Modular exponentiation: returns x^y mod M by repeated squaring.
// x may be any long long. It is reduced into [0, M) first, so neither the
// multiplication into r nor the squaring of x can overflow a signed long long.
// y must be >= 0: with a negative exponent the shift loop never reaches 0.
// NOTE(review): this helper is not called anywhere in this listing.
lld pot(lld x, lld y) {
	x %= M;
	if (x < 0)
		x += M;  // keep the base (and therefore the result) non-negative

	lld r = 1;
	while (y != 0) {
		if (y & 1)
			r = r * x % M;
		x = x * x % M;
		y >>= 1;
	}

	return r;
}

void init() {
	pos[0] = 1;
	fac[0] = 1;
	
	for (lld i = 0; i < min(n, m); ++i){
		fac[i + 1] = (fac[i] * (m - i)) % M;
	}
}

// Decides whether the labeling t[1..n] can be cut into consecutive segments,
// each of length >= 2, whose first and last labels are equal.
// pos[i] holds that answer for the prefix t[1..i]. pos[0] is the empty-prefix
// base case and must already be 1 (init() sets it).
// last[c] tracks the most recent index seen with label c.
// Before returning, pos[1..n] and last[1..n] are zeroed so the next call
// starts clean. (Labels produced by solve() never exceed n.)
bool check() {
	for (int i = 1; i <= n; ++i) {
		const int prev = last[t[i]];  // previous index with the same label, 0 if none
		// A segment can end at i if it starts at prev (the prefix before prev is
		// splittable: pos[prev - 1]). It can also end at i if it starts at an
		// earlier same-label index; pos[prev] already folds those cases in.
		pos[i] = prev != 0 && (pos[prev] || pos[prev - 1]);
		last[t[i]] = i;
	}

	const bool splittable = pos[n];

	for (int i = n; i >= 1; --i) {
		last[i] = 0;
		pos[i] = false;
	}

	return splittable;
}

// Enumerates the labelings of positions p..n in canonical first-occurrence
// order: each position either reuses a label 1..mx or opens label mx + 1,
// never exceeding m. Every equality pattern over t[1..n] with at most m
// distinct labels is therefore visited exactly once.
// mx is the number of distinct labels in t[1..p-1].
// For each complete labeling that check() accepts, adds fac[mx] to res mod M.
// fac[mx] is the number of ways to assign distinct colors to the mx labels.
void solve(lld p, lld mx) {
	if (p <= n) {
		const lld limit = min(mx + 1, m);
		for (lld label = 1; label <= limit; ++label) {
			t[p] = label;
			solve(p + 1, label > mx ? label : mx);
		}
		return;
	}

	if (!check())
		return;

	res += fac[mx];
	if (res >= M)
		res -= M;
}

int main(){
	scanf("%lld%lld", &n, &m);
	
	init();
	
	solve(1, 0);
	
	printf("%lld", res);
	
	return 0;
}