#include <iostream>
#include <limits>
#include <unordered_map>
#include <vector>

using namespace std;
using LL = long long;

// Sentinel returned by f() when no feasible assignment exists.
const LL NEG_INF = numeric_limits<LL>::min();

// mem[i] memoizes f(i, m), keyed by the bound m.
vector<unordered_map<LL, LL>> mem;
// a[i] is the weight multiplied by the popcount of the value chosen for index i.
vector<LL> a;

// Maximum popcount of any integer in [0, m]; 0 when m <= 0.
int max_bits(LL m) {
    if (m <= 0) return 0;
    const int hi = 63 - __builtin_clzll(m);  // index of the highest set bit of m
    // Only an all-ones m reaches hi + 1 bits; otherwise 2^hi - 1 (hi bits) is the best.
    return __builtin_popcountll(m) == hi + 1 ? hi + 1 : hi;
}

// Largest x with 0 <= x <= n and popcount(x) == k.
// Precondition: n >= 0 and 0 <= k <= max_bits(n) (so such an x exists).
// Returns LL: the previous `int` return type truncated values >= 2^31.
LL max_le_with_k_bits(LL n, int k) {
    int c = __builtin_popcountll(n);
    if (c >= k) {
        // Too many (or exactly k) bits: keep only the k highest set bits of n.
        // Every value in (result, n] shares that prefix plus extra low bits.
        while (c > k) {
            n &= n - 1;  // clear the lowest set bit
            --c;
        }
        return n;
    }
    // Too few bits: any x < n first differs from n at some set bit p of n
    // (n has 1, x has 0). The best x for a given p keeps n's bits above p and
    // fills the highest positions below p. A lower p gives a larger x, so take
    // the lowest p that leaves room: popcount(bits above p) + p >= k.
    for (int p = 0; p < 63; ++p) {
        if (((n >> p) & 1) == 0) continue;
        const LL prefix = (n >> (p + 1)) << (p + 1);  // bits of n strictly above p
        const int need = k - __builtin_popcountll(prefix);  // bits to place below p
        if (need <= p) {
            const LL fill = ((1LL << need) - 1) << (p - need);
            return prefix | fill;
        }
    }
    return -1;  // unreachable when the precondition holds
}

// Maximum of sum_{j=0..n} a[j] * popcount(x_j) over strictly increasing
// 0 <= x_0 < x_1 < ... < x_n <= m.
// Returns 0 when n < 0 (nothing to choose) and NEG_INF when infeasible.
// For a fixed popcount k of x_n, picking the largest such x_n <= m is optimal:
// it leaves the loosest bound (x_n - 1) for the remaining indices.
LL f(int n, LL m) {
    if (n < 0) return 0;
    if (m < 0) return NEG_INF;
    auto it = mem[n].find(m);
    if (it != mem[n].end()) return it->second;

    LL best = NEG_INF;
    const int mx = max_bits(m);
    for (int k = 0; k <= mx; ++k) {
        const LL rest = f(n - 1, max_le_with_k_bits(m, k) - 1);
        if (rest == NEG_INF) continue;  // earlier indices cannot fit below x_n
        best = max(best, rest + k * a[n]);
    }
    // Recursion only touches mem[n - 1], so inserting into mem[n] here is safe.
    mem[n][m] = best;
    return best;
}

// Reads n, m and a[0..n-1]; prints f(n - 1, m).
// NOTE(review): prints NEG_INF's value if no valid sequence exists (e.g. n > m + 1),
// matching the original behavior.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    LL m;
    if (!(cin >> n >> m)) return 0;  // previously UB on read failure (uninitialized n)
    if (n < 0) n = 0;                // same output as before: f(-1, m) == 0

    // Sized from input instead of a fixed MAX_N = 200 that could overflow.
    a.resize(n);
    mem.resize(n);
    for (LL& x : a) cin >> x;

    cout << f(n - 1, m) << '\n';
    return 0;
}