// Counts unordered triples of distinct contiguous subarrays whose sums add up
// to zero.
//
// Input  (stdin):  n, followed by n integers.
// Output (stdout): the number of such triples, followed by a newline.
//
// Approach: collect all n*(n+1)/2 subarray sums, sort them and compress them
// into distinct values with multiplicities. Then, for every candidate smallest
// value a, run a two-pointer scan over the sorted distinct values to find every
// pair b <= c (with a <= b) such that a + b + c == 0. Each value triple is
// weighted by the number of ways to pick distinct subarrays having those sums.
//
// NOTE(review): assumes every subarray sum fits in int64_t and the final count
// fits in uint64_t; confirm against the problem's limits.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

namespace {

// Number of ways to choose 2 items out of c.
std::uint64_t choose2(std::uint64_t c) { return c < 2 ? 0 : c * (c - 1) / 2; }

// Number of ways to choose 3 items out of c. c*(c-1)/2 is exact, and
// c*(c-1)*(c-2)/2 is always divisible by 3, so the final division is exact.
std::uint64_t choose3(std::uint64_t c) { return c < 3 ? 0 : choose2(c) * (c - 2) / 3; }

// Returns the sum of every contiguous subarray of nums (n*(n+1)/2 values).
std::vector<std::int64_t> subarray_sums(const std::vector<std::int64_t>& nums) {
    const std::size_t n = nums.size();
    std::vector<std::int64_t> sums;
    sums.reserve(n * (n + 1) / 2);
    for (std::size_t start = 0; start < n; ++start) {
        std::int64_t running = 0;
        for (std::size_t end = start; end < n; ++end) {
            running += nums[end];
            sums.push_back(running);
        }
    }
    return sums;
}

// Returns the number of unordered triples of distinct subarrays of nums whose
// sums total zero.
std::uint64_t count_zero_sum_triples(const std::vector<std::int64_t>& nums) {
    std::vector<std::int64_t> sums = subarray_sums(nums);
    std::sort(sums.begin(), sums.end());

    // Compress the sorted sums into strictly increasing distinct values and
    // how many subarrays produce each value.
    std::vector<std::int64_t> vals;
    std::vector<std::uint64_t> cnts;
    for (const std::int64_t s : sums) {
        if (vals.empty() || vals.back() != s) {
            vals.push_back(s);
            cnts.push_back(1);
        } else {
            ++cnts.back();
        }
    }

    const auto m = static_cast<std::ptrdiff_t>(vals.size());
    std::uint64_t total = 0;
    // vals[i] is the smallest sum of the triple; once it is positive the
    // three sums can no longer total zero.
    for (std::ptrdiff_t i = 0; i < m && vals[i] <= 0; ++i) {
        const std::int64_t target = -vals[i];  // need vals[j] + vals[k] == target
        std::ptrdiff_t j = i;                  // j may equal i: covers a == b
        std::ptrdiff_t k = m - 1;              // k may equal j: covers b == c
        while (j <= k) {
            const std::int64_t pair_sum = vals[j] + vals[k];
            if (pair_sum < target) {
                ++j;
            } else if (pair_sum > target) {
                --k;
            } else {
                // Values are distinct, so equal indices mean equal sums and
                // the subarrays must be drawn from the same multiplicity.
                if (i == j && j == k) {
                    total += choose3(cnts[i]);
                } else if (i == j) {
                    total += choose2(cnts[i]) * cnts[k];
                } else if (j == k) {
                    total += cnts[i] * choose2(cnts[j]);
                } else {
                    total += cnts[i] * cnts[j] * cnts[k];
                }
                ++j;
                --k;
            }
        }
    }
    return total;
}

}  // namespace

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    // No (or a non-positive) element count means there are no subarrays.
    if (!(std::cin >> n) || n <= 0) {
        std::cout << 0 << '\n';
        return 0;
    }

    std::vector<std::int64_t> nums(static_cast<std::size_t>(n));
    for (std::int64_t& x : nums) {
        std::cin >> x;
    }

    std::cout << count_zero_sum_triples(nums) << '\n';
    return 0;
}