#include <bits/stdc++.h>
#define PB push_back
#define ST first
#define ND second
//#pragma GCC optimize ("O3")
//#pragma GCC target("tune=native")
//mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
//#include <ext/pb_ds/assoc_container.hpp>
//#include <ext/pb_ds/tree_policy.hpp>
//using namespace __gnu_pbds;
//typedef tree<int, null_type, less_equal<int>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
using namespace std;

// Short type aliases shared by the rest of the file.
using ll = long long;            // wide type for modular products
using pi = std::pair<int, int>;
using vi = std::vector<int>;

// Capacity of the j-dimension of `dp` (presumably max n of 3000 plus slack)
// and the modulus every result is reduced by.
constexpr int nax = 3000 + 10;
constexpr int mod = 1000000007;

// Problem parameters, filled from stdin in main().
int n = 0;
int m = 0;

// Rolling DP table indexed as dp[row parity][j][state]; zero-initialized.
int dp[2][nax][2] = {};
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cin >> n >> m;
dp[0][0][1] = 1;
dp[1][1][0] = m;
for(int i = 2; i <= n; ++i) {
for(int j = 0; j <= n; ++j) {
dp[i&1][j][0] = 0;
}
int bt = i & 1;
for(int j = 0; j <= n; ++j) {
if(j > 0)
dp[bt][j][0] = ((ll)dp[bt^1][j - 1][1] * (m - j+1)) % mod;
dp[bt][j][0] = (dp[bt][j][0] + (ll)dp[bt^1][j][0] * (m - j)) % mod;
dp[bt][j][1] = ((ll)(dp[bt^1][j][0] + dp[bt^1][j][1]) * j) % mod;
}
}
int ans = 0;
for(int i = 0; i <= n; ++i) {
ans = (ans + dp[n&1][i][1]) % mod;
}
if(ans < 0) ans += mod;
cout << ans;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 | #include <bits/stdc++.h> #define PB push_back #define ST first #define ND second //#pragma GCC optimize ("O3") //#pragma GCC target("tune=native") //mt19937 rng(chrono::steady_clock::now().time_since_epoch().count()); //#include <ext/pb_ds/assoc_container.hpp> //#include <ext/pb_ds/tree_policy.hpp> //using namespace __gnu_pbds; //typedef tree<int, null_type, less_equal<int>, rb_tree_tag, tree_order_statistics_node_update> ordered_set; using namespace std; using ll = long long; using pi = pair<int, int>; using vi = vector<int>; const int nax = 3000 + 10, mod = 1e9 + 7; int n, m; int dp[2][nax][2]; int main() { ios_base::sync_with_stdio(0); cin.tie(0); cin >> n >> m; dp[0][0][1] = 1; dp[1][1][0] = m; for(int i = 2; i <= n; ++i) { for(int j = 0; j <= n; ++j) { dp[i&1][j][0] = 0; } int bt = i & 1; for(int j = 0; j <= n; ++j) { if(j > 0) dp[bt][j][0] = ((ll)dp[bt^1][j - 1][1] * (m - j+1)) % mod; dp[bt][j][0] = (dp[bt][j][0] + (ll)dp[bt^1][j][0] * (m - j)) % mod; dp[bt][j][1] = ((ll)(dp[bt^1][j][0] + dp[bt^1][j][1]) * j) % mod; } } int ans = 0; for(int i = 0; i <= n; ++i) { ans = (ans + dp[n&1][i][1]) % mod; } if(ans < 0) ans += mod; cout << ans; } |
English