#include <cstdio> #include <cstring> #include <cassert> #include <algorithm> using int64 = long long; const int N = 200 + 10; const int64 inf = 1ll << 61; bool mark[61][N][N]; int64 dp[61][N][N]; int64 s[N], m; int n; int64 solve(int x, int l, int r) { if (x == -1) assert(l == r); if (x == -1) return 0; if (mark[x][l][r]) return dp[x][l][r]; mark[x][l][r] = 1; int64 ret = -inf; for (int u = 0; u <= r - l + 1; ++u) { int v = r - l + 1 - u; if (std::max(u, v) > (1ll << x)) continue; int64 tmp = 0; if (u) tmp += solve(x - 1, l, l + u - 1); if (v) tmp += solve(x - 1, l + u, r) + s[r] - s[l + u - 1]; ret = std::max(ret, tmp); } return dp[x][l][r] = ret; } bool mark2[61][N]; int64 dp2[61][N]; int64 solve2(int x, int l) { assert(l <= n); if (x == -1 && l != n) return -inf; if (x == -1) return 0; if (mark2[x][l]) return dp2[x][l]; mark2[x][l] = 1; int64 ret = -inf; int o = m >> x & 1; if (!o) ret = solve2(x - 1, l); else { for (int u = 0; u <= n - l + 1; ++u) { if ((1ll << x) < u) continue; int64 tmp = 0; if (u) tmp += solve(x - 1, l, l + u - 1); if (u != n - l + 1) { if (solve2(x - 1, l + u) == -inf) continue; tmp += solve2(x - 1, l + u) + s[n] - s[l + u - 1]; } ret = std::max(ret, tmp); } } return dp2[x][l] = ret; } int main() { scanf("%d%lld", &n, &m); for (int i = 1; i <= n; ++i) { scanf("%lld", &s[i]); s[i] += s[i - 1]; } memset(mark, 0, sizeof(mark)); memset(mark2, 0, sizeof(mark2)); printf("%lld\n", solve2(60, 1)); return 0; }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 | #include <cstdio> #include <cstring> #include <cassert> #include <algorithm> using int64 = long long; const int N = 200 + 10; const int64 inf = 1ll << 61; bool mark[61][N][N]; int64 dp[61][N][N]; int64 s[N], m; int n; int64 solve(int x, int l, int r) { if (x == -1) assert(l == r); if (x == -1) return 0; if (mark[x][l][r]) return dp[x][l][r]; mark[x][l][r] = 1; int64 ret = -inf; for (int u = 0; u <= r - l + 1; ++u) { int v = r - l + 1 - u; if (std::max(u, v) > (1ll << x)) continue; int64 tmp = 0; if (u) tmp += solve(x - 1, l, l + u - 1); if (v) tmp += solve(x - 1, l + u, r) + s[r] - s[l + u - 1]; ret = std::max(ret, tmp); } return dp[x][l][r] = ret; } bool mark2[61][N]; int64 dp2[61][N]; int64 solve2(int x, int l) { assert(l <= n); if (x == -1 && l != n) return -inf; if (x == -1) return 0; if (mark2[x][l]) return dp2[x][l]; mark2[x][l] = 1; int64 ret = -inf; int o = m >> x & 1; if (!o) ret = solve2(x - 1, l); else { for (int u = 0; u <= n - l + 1; ++u) { if ((1ll << x) < u) continue; int64 tmp = 0; if (u) tmp += solve(x - 1, l, l + u - 1); if (u != n - l + 1) { if (solve2(x - 1, l + u) == -inf) continue; tmp += solve2(x - 1, l + u) + s[n] - s[l + u - 1]; } ret = std::max(ret, tmp); } } return dp2[x][l] = ret; } int main() { scanf("%d%lld", &n, &m); for (int i = 1; i <= n; ++i) { scanf("%lld", &s[i]); s[i] += s[i - 1]; } memset(mark, 0, sizeof(mark)); memset(mark2, 0, sizeof(mark2)); printf("%lld\n", solve2(60, 1)); return 0; } |