#include <bits/stdc++.h>
// Shorthand macros. These are textual substitutions, so every later token
// with one of these names is rewritten (e.g. a variable named `pi` or `mod`).
// `vi` and `pi` expand to unqualified vector/pair, so they rely on
// `using namespace std;` being in effect where they are expanded.
// mp/fi/se/pb/vi/pi are not referenced by the code visible in this file.
#define ll long long
#define mp make_pair
#define fi first
#define se second
#define pb push_back
#define vi vector<int>
#define pi pair<int, int>
// Modulus applied by ksm() and by the DP in main().
#define mod 1000000007
// Lowers `a` to `b` when `b` compares strictly less (via operator<).
// Returns true exactly when `a` was overwritten.
template<typename T> bool chkmin(T &a, T b) {
    if (!(b < a)) return false;  // already minimal (or unordered): leave `a` alone
    a = b;
    return true;
}
// Raises `a` to `b` when `b` compares strictly greater (via operator>).
// Returns true exactly when `a` was overwritten.
template<typename T> bool chkmax(T &a, T b) {
    if (b > a) {
        a = b;
        return true;
    }
    return false;
}
// Modular exponentiation by squaring: returns a^b modulo `mod`, in [0, mod).
//
// The base is reduced (and made non-negative) first. Every product is then
// of two values < mod, which fits in long long; the previous version
// multiplied by the unreduced `a` and could overflow for large bases.
// The exponent must be non-negative: a negative `b` is not supported and
// yields 1. The previous recursive form never terminated there, because
// `b >> 1` stays at -1.
ll ksm(ll a, ll b) {
    a %= mod;
    if (a < 0) a += mod;  // % truncates toward zero; normalise the base
    ll result = 1;
    while (b > 0) {
        if (b & 1) result = result * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return result;
}
using namespace std;
int n, m;
const int maxn = 3005;
ll dp[maxn][maxn][2]; // a filled, b valid characters, valid or not
int main() {
cin >> n >> m;
dp[0][0][1] = 1;
for (int i = 0; i < n; i++) {
for (int j = 0; j <= i; j++)
for (int k = 0; k < 2; k++) {
// valid chara
dp[i + 1][j][1] = (dp[i + 1][j][1] + dp[i][j][k] * j) % mod;
dp[i + 1][j + k][0] = (dp[i + 1][j + k][0] + (m - j) * dp[i][j][k]) % mod;
}
}
ll ans = 0;
for (int j = 0; j <= n; j++)
ans += dp[n][j][1];
ans %= mod;
if (ans < 0) ans += mod;
cout << ans << endl;
return (0-0); // <3 cxr
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 | #include <bits/stdc++.h> #define ll long long #define mp make_pair #define fi first #define se second #define pb push_back #define vi vector<int> #define pi pair<int, int> #define mod 1000000007 template<typename T> bool chkmin(T &a, T b){return (b < a) ? a = b, 1 : 0;} template<typename T> bool chkmax(T &a, T b){return (b > a) ? a = b, 1 : 0;} ll ksm(ll a, ll b) {if (b == 0) return 1; ll ns = ksm(a, b >> 1); ns = ns * ns % mod; if (b & 1) ns = ns * a % mod; return ns;} using namespace std; int n, m; const int maxn = 3005; ll dp[maxn][maxn][2]; // a filled, b valid characters, valid or not int main() { cin >> n >> m; dp[0][0][1] = 1; for (int i = 0; i < n; i++) { for (int j = 0; j <= i; j++) for (int k = 0; k < 2; k++) { // valid chara dp[i + 1][j][1] = (dp[i + 1][j][1] + dp[i][j][k] * j) % mod; dp[i + 1][j + k][0] = (dp[i + 1][j + k][0] + (m - j) * dp[i][j][k]) % mod; } } ll ans = 0; for (int j = 0; j <= n; j++) ans += dp[n][j][1]; ans %= mod; if (ans < 0) ans += mod; cout << ans << endl; return (0-0); // <3 cxr } |
English