#include <iostream>
#include <vector>
#include <bitset>
#include <unordered_map>
#include <unordered_set>
#include <tuple>
#include <utility>
#include <optional>
#include <algorithm>
#include <cstddef>
#include <cstdint>

/// Hash for the memoisation key (begin, end, low, high).
///
/// The previous `begin * low + end * high` collided heavily (every key with
/// begin == 0 hashed to end * high alone), so each field is folded into the
/// seed with the usual hash_combine mixing step instead.
class Hash {
public:
    size_t operator()(std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> const &tuple) const {
        const uint64_t fields[] = {std::get<0>(tuple), std::get<1>(tuple),
                                   std::get<2>(tuple), std::get<3>(tuple)};
        uint64_t seed = 0;
        for (const uint64_t value : fields) {
            seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
        }
        return static_cast<size_t>(seed);
    }
};

// Values read from stdin by main(); split() scores input[begin] at its leaves.
std::vector<int64_t> input;
// Memo for split(), keyed by (begin, end, low, high). Only non-leaf ranges are stored.
std::unordered_map<std::tuple<uint64_t, uint64_t, uint64_t, uint64_t>, std::optional<int64_t>, Hash> results;

/// Best total for the elements input[begin, end) placed into the value range [low, high).
///
/// A single element input[i] scores input[i] * popcount((low + high) / 2), i.e. it takes
/// the midpoint of its value range. A longer range is halved by index, and the value
/// range is cut at one of two candidate points -- (2*low + high) / 3 or (low + 2*high) / 3;
/// the left half receives [low, cut) and the right half [cut, high). The larger total over
/// the feasible candidates is returned, and the result is memoised in `results`.
///
/// NOTE(review): only two cut points are tried and leaves always take their midpoint, so
/// this is a heuristic search rather than an exhaustive one. It can report "no assignment"
/// even when high - low >= end - begin (e.g. split(0, 5, 1, 6) returns nullopt).
///
/// @return the best total, or std::nullopt if this search found no assignment.
std::optional<int64_t> split(uint64_t begin, uint64_t end, uint64_t low, uint64_t high) {
    const uint64_t count = end - begin;

    // No elements: nothing to score. (Without this, begin == end recursed forever.)
    if (count == 0) return int64_t{0};

    // Fewer values than elements can never be completed: each half needs at least one
    // value per element, so one side of every cut ends up short. The old test compared
    // `low - high`, which wraps for low < high and therefore only pruned low == high.
    if (high < low || high - low < count) return std::nullopt;

    if (count == 1) {
        // Leaf: O(1) to compute, so it is not worth a memo entry.
        const std::bitset<64> bits((low + high) / 2);
        return input[begin] * static_cast<int64_t>(bits.count());
    }

    const auto key = std::make_tuple(begin, end, low, high);
    if (const auto it = results.find(key); it != results.end()) return it->second;

    const uint64_t mid = begin + count / 2;  // same as (begin + end) / 2, without overflow
    const uint64_t lower = (2 * low + high) / 3;
    const uint64_t higher = (low + 2 * high) / 3;

    // Total when the left half takes [low, cut) and the right half takes [cut, high).
    const auto try_cut = [&](uint64_t cut) -> std::optional<int64_t> {
        const std::optional<int64_t> left = split(begin, mid, low, cut);
        if (!left) return std::nullopt;
        const std::optional<int64_t> right = split(mid, end, cut, high);
        if (!right) return std::nullopt;
        return *left + *right;
    };

    const std::optional<int64_t> at_lower = try_cut(lower);
    const std::optional<int64_t> at_higher = try_cut(higher);

    std::optional<int64_t> best;
    if (at_lower && at_higher) {
        best = std::max(*at_lower, *at_higher);
    } else if (at_lower) {
        best = at_lower;
    } else {
        best = at_higher;  // may itself be nullopt: neither cut worked
    }

    results.emplace(key, best);
    return best;
}

/// Reads n, m and then n integers from stdin, and prints the best total that split()
/// finds for placing all n values into [1, m + 1).
///
/// `.value()` throws std::bad_optional_access when split() finds no assignment, which is
/// always the case when n > m.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int64_t n = 0;
    int64_t m = 0;
    if (!(std::cin >> n >> m) || n < 0 || m < 0) {
        std::cerr << "expected non-negative n and m\n";
        return 1;
    }

    for (int64_t i = 0; i < n; ++i) {
        int64_t value = 0;
        std::cin >> value;
        input.push_back(value);
    }

    std::cout << split(0, static_cast<uint64_t>(n), 1, static_cast<uint64_t>(m) + 1).value() << '\n';
}