#include <bits/stdc++.h>
using namespace std;

using ll = long long;

// One search state: `m` is the value most recently chosen and `res` is the
// score accumulated up to and including that value.  States are ordered by
// `m` first and by `res` second, so a std::set of them is sorted by `m`.
struct result {
    ll res;
    ll m;

    bool operator<(result const& b) const {
        if (m != b.m) return m < b.m;
        return res < b.res;
    }
};

// `current` holds the states after the values processed so far; `nxt` is the
// frontier being built for the value currently being processed.
set<result> current, nxt;
int n, lg;  // n: number of input values; lg: number of binary digits of m
ll m;       // upper bound on every chosen value (read from the input)

// Returns the number whose lowest `bits` bits are all ones (1 <= bits <= 63).
// The shift is done on an unsigned value so that bits == 63 is well defined.
static ll lowMask(int bits) {
    return static_cast<ll>((1ull << bits) - 1);
}

// Offers state `r` to `nxt`, keeping `nxt` a frontier in which `res` grows
// strictly as `m` grows:
//   * `r` is rejected if an existing state has the same `m` and a larger
//     `res`, or a smaller `m` and a `res` that is at least as large;
//   * existing states with a larger-or-equal `m` and a `res` that is not
//     larger than `r.res` are removed, because `r` is at least as good.
void maybeInsert(result r) {
    auto it = nxt.upper_bound(r);

    // The first element greater than r has the same m: it has a larger res.
    if (it != nxt.end() && it->m == r.m) {
        return;
    }

    // Drop every following state (larger m) that does not beat r's score.
    while (it != nxt.end() && it->res <= r.res) {
        it = nxt.erase(it);
    }

    if (it != nxt.begin()) {
        auto before = prev(it);
        if (before->m == r.m) {
            // Same m and res <= r.res: r replaces it.  If the erased state
            // was the first element, r must still be inserted below;
            // returning here would lose both states.
            it = nxt.erase(before);
        }
        // A state with a smaller m and an equal or better score dominates r.
        if (it != nxt.begin() && prev(it)->res >= r.res) {
            return;
        }
    }

    nxt.insert(it, r);
}

// Generates the successors of state `r` for a value with weight `a` and
// offers them to `nxt`.  Every successor value is strictly greater than
// `r.m` and at most `m`; its score is `r.res + a * popcount(value)`.
// `used` is scratch space indexed by popcount: for each source state only the
// first candidate encountered with a given popcount is offered.
static void expandState(const result& r, ll a, vector<int>& used) {
    used.assign(lg + 1, 0);

    // Offers `value` (whose popcount is `pc`) unless that popcount was
    // already offered for this source state.  Returns whether it was offered.
    auto offer = [&](int pc, ll value) -> bool {
        if (used[pc]) {
            return false;
        }
        maybeInsert({r.res + a * pc, value});
        used[pc] = 1;
        return true;
    };

    // Candidates of the form "r.m with its lowest k2+1 bits all set".
    // `pc` tracks the popcount of the candidate incrementally.
    int pc = __builtin_popcountll(r.m);
    for (int k2 = 0; k2 < lg; k2++) {
        const ll m1 = r.m | lowMask(k2 + 1);
        pc += (r.m >> k2) % 2 == 0;
        if (m1 > r.m && m1 <= m) {
            offer(pc, m1);
        }
    }

    // Candidates of the form "r.m with bit k set and all lower bits cleared"
    // (only larger than r.m when bit k of r.m is 0), followed by the same
    // value with its lowest bits filled with ones.
    for (int k = 0; k < lg; k++) {
        const ll m2 = ((r.m >> k) | 1) << k;
        if (m2 <= r.m || m2 > m) {
            continue;
        }
        pc = __builtin_popcountll(m2);
        if (!offer(pc, m2)) {
            continue;
        }
        for (int k2 = 0; k2 < lg; k2++) {
            const ll m1 = m2 | lowMask(k2 + 1);
            pc += (m2 >> k2) % 2 == 0;
            if (m1 > m) {
                continue;
            }
            // Stop filling as soon as a popcount that was already offered
            // is reached.
            if (!offer(pc, m1)) {
                break;
            }
        }
    }
}

// Reads `n` and `m`, then `n` weights.  Presumably this picks a strictly
// increasing sequence of n values in [0, m] maximising the sum of
// weight * popcount(value) -- TODO confirm against the problem statement.
// Prints the largest `res` among the states that remain after the last
// weight (or the sentinel -4000000000000000000 when no state remains).
int main() {
    if (scanf("%d %lld", &n, &m) != 2) {
        return 1;
    }
    for (ll t = m; t; t /= 2) {
        lg++;
    }

    vector<int> used;  // scratch buffer reused by expandState
    for (int i = 0; i < n; i++) {
        ll a;
        if (scanf("%lld", &a) != 1) {
            return 1;
        }

        if (i == 0) {
            // Seed states for the first weight: the value 0 and every
            // all-ones value 2^(k+1)-1 that does not exceed m.
            current.insert({0, 0});
            for (int k = 0; k < lg; k++) {
                const ll t = lowMask(k + 1);
                if (t <= m) {
                    current.insert({a * (k + 1), t});
                }
            }
        } else {
            for (const result& r : current) {
                expandState(r, a, used);
            }
            current.swap(nxt);
            nxt.clear();
        }
    }

    printf("%lld\n",
           accumulate(current.begin(), current.end(), -4000000000000000000ll,
                      [](ll best, const result& r) { return max(best, r.res); }));
}