1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include <stdio.h>
#include <vector>
// Smaller of two values; each argument may be evaluated twice.
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// Modulus applied to every count.
const int M = 1000000007;
// Extent of the first two dp indices.
const int C = 5005;
using ll = long long;
// dp[len_seq][valid_colors][valid]: roughly 400 MB of zero-initialised storage.
ll dp[C][C][2];
int main(){
	int i, n, m, j;
	ll k;

	scanf ("%d %lld", &n, &k);

	dp[1][1][0] = k;
	for (i=1; i<n; i++){
		m = MIN(k, i+1);
		for (j=0; j<=m; j++){
			dp[i][j][0] %= M;
			dp[i][j][1] %= M;

			dp[i+1][j][1] += (dp[i][j][1]*j)%M; //Add valid color - end of valid
			dp[i+1][j+1][0] += (dp[i][j][1]*(k-j))%M; //Add invalid color - end of valid
			dp[i+1][j][1] += (dp[i][j][0]*j)%M; //Add valid color - end of invalid
			dp[i+1][j][0] += (dp[i][j][0]*(k-j))%M; //Add invalid color - end of invalid
		}
	}

	ll res=0;
	for (i=1; i<=n; i++) res = (res+dp[n][i][1])%M;
	if (res<0) res+=M;
	printf ("%lld\n", res);
return 0;}