// #include <bits/stdc++.h>
#define FAST ios_base::sync_with_stdio(false), cin.tie(0), cout.tie(0);
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>
using namespace std;
// Counts unordered triples of contiguous subarrays whose sums add up to zero.
//
// Input  (stdin):  n, followed by n integers.
// Output (stdout): the number of such triples, with no trailing newline.
//
// NOTE(review): the problem statement is not part of this file; the counting
// rule (three different subarrays, order irrelevant) is inferred from the
// combination formulas the original code used. Confirm against the task.
//
// Sums are accumulated in 64-bit signed integers and doubled once during the
// search, so inputs are assumed to keep every subarray sum well inside the
// int64 range.

namespace {

using SumTally = std::unordered_map<std::int64_t, std::uint64_t>;
using SumCount = std::pair<std::int64_t, std::uint64_t>;

// Number of ways to pick 2 items out of k.
std::uint64_t choose2(std::uint64_t k)
{
    return k < 2 ? 0 : k * (k - 1) / 2;
}

// Number of ways to pick 3 items out of k.
std::uint64_t choose3(std::uint64_t k)
{
    // choose2(k) * (k - 2) equals 3 * C(k, 3), so the division is exact.
    return k < 3 ? 0 : choose2(k) * (k - 2) / 3;
}

// Maps every distinct contiguous-subarray sum of `nums` to how many
// subarrays produce it.
SumTally tally_subarray_sums(const std::vector<std::int64_t>& nums)
{
    // No up-front reserve: the number of distinct sums is often far below the
    // n*(n+1)/2 upper bound, and a reserve sized from n*n both risks integer
    // overflow and can request far more memory than is ever used.
    SumTally tally;
    const std::size_t n = nums.size();
    for (std::size_t start = 0; start < n; ++start) {
        std::int64_t running = 0;
        for (std::size_t end = start; end < n; ++end) {
            running += nums[end];
            ++tally[running];
        }
    }
    return tally;
}

// Given the multiplicity of each distinct sum, counts the unordered triples
// of items whose values add up to zero. Every value shape a <= b <= c is
// handled exactly once:
//   0 + 0 + 0        -> C(count[0], 3)
//   x + x + (-2x)    -> C(count[x], 2) * count[-2x]   (x != 0, either sign)
//   a < b < c        -> count[a] * count[b] * count[c]
std::uint64_t count_zero_sum_triples(const SumTally& tally)
{
    std::vector<SumCount> sorted(tally.begin(), tally.end());
    std::sort(sorted.begin(), sorted.end());  // keys are unique, so this orders by sum

    std::uint64_t total = 0;

    // Shape 0 + 0 + 0.
    const auto zero = tally.find(0);
    if (zero != tally.end()) {
        total += choose3(zero->second);
    }

    // Shape x + x + (-2x). Covers both "two equal small values" (x < 0) and
    // "two equal large values" (x > 0), so the result is symmetric under
    // negating the whole input.
    for (const SumCount& entry : sorted) {
        if (entry.first == 0 || entry.second < 2) {
            continue;
        }
        const auto partner = tally.find(-2 * entry.first);
        if (partner != tally.end()) {
            total += choose2(entry.second) * partner->second;
        }
    }

    // Shape a < b < c: fix the smallest value and close in on the other two
    // from both ends of the sorted list.
    const std::size_t m = sorted.size();
    for (std::size_t i = 0; i + 2 < m; ++i) {
        if (sorted[i].first >= 0) {
            break;  // the smallest of three distinct values summing to 0 is negative
        }
        std::size_t lo = i + 1;
        std::size_t hi = m - 1;
        while (lo < hi) {
            const std::int64_t s =
                sorted[i].first + sorted[lo].first + sorted[hi].first;
            if (s < 0) {
                ++lo;
            } else if (s > 0) {
                --hi;
            } else {
                total += sorted[i].second * sorted[lo].second * sorted[hi].second;
                ++lo;
                --hi;
            }
        }
    }

    return total;
}

}  // namespace

int main()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    if (!(std::cin >> n) || n <= 0) {
        // No elements means no subarrays and therefore no triples.
        std::cout << 0;
        return 0;
    }

    std::vector<std::int64_t> nums(static_cast<std::size_t>(n));
    for (std::int64_t& value : nums) {
        std::cin >> value;
    }

    std::cout << count_zero_sum_triples(tally_subarray_sums(nums));
    return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | // #include <bits/stdc++.h> #define FAST ios_base::sync_with_stdio(false), cin.tie(0), cout.tie(0); #include <iostream> #include <vector> #include <unordered_map> #include <algorithm> using namespace std; int main() { FAST int n; cin >> n; vector<int64_t> nums(n); unordered_map<int64_t, uint64_t> nums_count; nums_count.reserve(8*n*n); int64_t sum; for (int i = 0; i < n ; i++) { cin >> nums[i]; } int64_t max = nums[0]; int64_t min = nums[0]; for (int i = 0; i < n ; i++) { sum = 0 ; for (int j = i ; j < n ; j++) { sum+=nums[j]; nums_count[sum]++; } } vector<int64_t> numz(nums_count.size()); vector<uint64_t> countz(nums_count.size()); int i = 0 ; for (auto v: nums_count) { numz[i++] = v.first; } sort(numz.begin(), numz.end()); for (int i = 0 ; i< numz.size(); i++) { countz[i] = nums_count[numz[i]]; } uint64_t total = 0; int64_t v3; uint64_t v3c; if (nums_count.count(0) && nums_count[0] >=3) { v3c = nums_count[0]; if (v3c >= 3) { total+=(v3c*(v3c-1)/2)*(v3c-2)/3; } } i = 0; int maxI = numz.size()-1; while(i<maxI && numz[i]+2*numz[maxI]<0) i++; // if () // v3 = -(numz[i])*2; for (; i< maxI ; i++) { while (maxI>i && numz[maxI]+2*numz[i]>0)maxI--; for (int j = i+1 ; j < maxI ; j++ ){ v3 = -(numz[i]+numz[j]); if (v3 < numz[j]) { break; } if (v3 > numz[maxI] || !nums_count.count(v3)) { continue; } if (numz[j] == v3) { total+=(countz[j]*(countz[j]-1)/2)*countz[i]; } else { total+=(countz[i]*countz[j]* nums_count[v3]); } } } cout << total; } |
English