#include <algorithm>
#include <iostream>
#include <limits>
#include <unordered_map>
using namespace std;
using LL = long long;  // wide signed type for bounds and weighted sums

// Capacity of both per-index arrays below.
constexpr int MAX_N = 200;

// Input weights a[0..n-1]; static storage, so zero-initialised.
LL a[MAX_N];

// Memo table: mem[i] maps a bound m to the cached result of f(i, m).
std::unordered_map<LL, LL> mem[MAX_N];
// Largest number of set bits any value in [0, m] can have; 0 when m <= 0.
// For m > 0 with bit width w this is w when m is all ones (m == 2^w - 1),
// otherwise w - 1 (e.g. 2^(w-1) - 1).
int max_bits(LL m) {
    if (m <= 0) {
        return 0;
    }
    const int width = 64 - __builtin_clzll(m);  // number of significant bits
    const bool all_ones = __builtin_popcountll(m) == width;
    return all_ones ? width : width - 1;
}
// Largest x with 0 <= x <= n and exactly k set bits, or -1 if no such x
// exists (including n < 0 or k < 0).
//
// - popcount(n) >= k: keeping only the k highest set bits of n is optimal.
// - popcount(n) <  k: x must be < n, so x equals n above some set bit i of n,
//   has bit i cleared, and places the remaining bits in the highest free
//   positions below i. The lowest feasible i keeps the longest prefix of n
//   and therefore gives the largest x.
LL max_le_with_k_bits(LL n, int k) {
    if (n < 0 || k < 0) {
        return -1;
    }
    int c = __builtin_popcountll(n);
    if (c >= k) {
        while (c > k) {  // drop the lowest set bit until k remain
            n &= n - 1;
            --c;
        }
        return n;
    }
    for (int i = 0; i < 63; ++i) {
        if (((n >> i) & 1) == 0) {
            continue;
        }
        const int above = __builtin_popcountll(n >> (i + 1));  // kept prefix bits
        const int need = k - above;  // bits still to place below position i
        if (need <= i) {
            const LL prefix = (n >> (i + 1)) << (i + 1);          // n with bits <= i cleared
            const LL fill = (1LL << i) - (1LL << (i - need));     // bits i-1 .. i-need
            return prefix | fill;
        }
    }
    return -1;  // k exceeds the most bits any value in [0, n] can hold
}
// Best total for indices 0..n when the number chosen for index n is at most m.
//
// The bound passed to index n-1 is x - 1, so chosen numbers are strictly
// increasing and the number for index n is the largest of them. For each
// popcount k the helper returns the largest x <= m with k set bits. That x
// leaves the widest bound for the earlier indices, and k * a[n] is added.
//
// n: highest index still to assign (n < 0 means nothing is left).
// m: inclusive upper bound for the number chosen at index n.
// Returns the best total, 0 when n < 0, or numeric_limits<LL>::min() when
// no valid assignment exists. Results are memoised in mem[n].
LL f(int n, LL m) {
    const LL kInfeasible = std::numeric_limits<LL>::min();
    if (n < 0) {
        return 0;
    }
    if (m < 0) {
        return kInfeasible;
    }
    const auto cached = mem[n].find(m);  // single hash lookup
    if (cached != mem[n].end()) {
        return cached->second;
    }
    LL best = kInfeasible;
    const int max_k = max_bits(m);
    for (int k = 0; k <= max_k; ++k) {
        const LL x = max_le_with_k_bits(m, k);
        if (x < 0) {
            continue;  // no usable number <= m with k set bits
        }
        const LL rest = f(n - 1, x - 1);
        if (rest == kInfeasible) {
            continue;
        }
        best = std::max(best, rest + k * a[n]);
    }
    mem[n][m] = best;
    return best;
}
int main() {
int n;
LL m;
cin >> n >> m;
for(int i = 0; i < n; ++i) cin >> a[i];
cout << f(n - 1, m) << endl;
return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 | #include <iostream> #include <unordered_map> #include <limits> using namespace std; using LL = long long; const int MAX_N = 200; unordered_map<LL, LL> mem[MAX_N]; LL a[MAX_N]; int max_bits(LL m) { if(m <= 0) return 0; int t = 63 - __builtin_clzll(m); if(__builtin_popcountll(m) == (63 - __builtin_clzll(m)) + 1) t++; return t; } int max_le_with_k_bits(LL n, int k) { int c = __builtin_popcountll(n); while(c > k) { n &= n - 1; c--; } if(c == k) return n; int t = 63 - __builtin_clzll(n); return (1ULL << t) - (1ULL << (t - k)); } LL f(int n, LL m) { if(n < 0) return 0; if(m < 0) return numeric_limits<LL>::min(); if(mem[n].count(m)) return mem[n][m]; int mx = max_bits(m); LL res = numeric_limits<LL>::min(); for(int k = 0; k <= mx; ++k) { LL t = f(n - 1, max_le_with_k_bits(m, k) - 1); if(t == numeric_limits<LL>::min()) continue; res = max(res, t + k * a[n]); } return mem[n][m] = res; } int main() { int n; LL m; cin >> n >> m; for(int i = 0; i < n; ++i) cin >> a[i]; cout << f(n - 1, m) << endl; return 0; } |
English