1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
// #include <bits/stdc++.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>
#define FAST ios_base::sync_with_stdio(false), cin.tie(0), cout.tie(0);
using namespace std;

namespace {

/// Number of ways to pick 2 distinct items out of k (k choose 2).
std::uint64_t choose2(std::uint64_t k)
{
    return k < 2 ? 0 : k * (k - 1) / 2;
}

/// Number of ways to pick 3 distinct items out of k (k choose 3).
/// Divides in stages: k(k-1)/2 * (k-2) is always a multiple of 3, and the
/// intermediate value stays smaller than k(k-1)(k-2).
std::uint64_t choose3(std::uint64_t k)
{
    return k < 3 ? 0 : choose2(k) * (k - 2) / 3;
}

/// Maps every sum of a non-empty contiguous subarray of `nums` to the number
/// of subarrays having that sum.
std::unordered_map<std::int64_t, std::uint64_t>
subarray_sum_counts(const std::vector<std::int64_t>& nums)
{
    const std::size_t n = nums.size();
    std::unordered_map<std::int64_t, std::uint64_t> counts;
    // There are exactly n(n+1)/2 subarrays, which bounds the number of distinct sums.
    counts.reserve(n * (n + 1) / 2);
    for (std::size_t first = 0; first < n; ++first) {
        std::int64_t sum = 0;
        for (std::size_t last = first; last < n; ++last) {
            sum += nums[last];
            ++counts[sum];
        }
    }
    return counts;
}

/// Counts unordered triples of distinct non-empty contiguous subarrays of
/// `nums` whose three sums add up to zero.
///
/// Each triple is enumerated exactly once by its sums a <= b <= c; equal sums
/// are resolved with binomial coefficients so no subarray is used twice.
std::uint64_t count_zero_sum_triples(const std::vector<std::int64_t>& nums)
{
    const std::unordered_map<std::int64_t, std::uint64_t> counts = subarray_sum_counts(nums);

    std::vector<std::pair<std::int64_t, std::uint64_t>> sorted(counts.begin(), counts.end());
    std::sort(sorted.begin(), sorted.end());

    std::uint64_t total = 0;
    for (std::size_t i = 0; i < sorted.size(); ++i) {
        const std::int64_t a = sorted[i].first;
        const std::uint64_t count_a = sorted[i].second;
        // The smallest of three values summing to zero cannot be positive.
        if (a > 0) {
            break;
        }
        // j starts at i (not i + 1) so that the a == b case is included.
        for (std::size_t j = i; j < sorted.size(); ++j) {
            const std::int64_t b = sorted[j].first;
            const std::uint64_t count_b = sorted[j].second;
            const std::int64_t c = -(a + b);
            // c only shrinks as b grows, so once c < b no later j can qualify.
            if (c < b) {
                break;
            }
            const auto found = counts.find(c);
            if (found == counts.end()) {
                continue;
            }
            const std::uint64_t count_c = found->second;
            if (a == b && b == c) {          // only possible when a == b == c == 0
                total += choose3(count_a);
            } else if (a == b) {             // a == b < c
                total += choose2(count_a) * count_c;
            } else if (b == c) {             // a < b == c
                total += count_a * choose2(count_b);
            } else {                         // a < b < c
                total += count_a * count_b * count_c;
            }
        }
    }
    return total;
}

}  // namespace

/// Reads n followed by n integers from stdin and prints the number of
/// unordered triples of distinct subarrays whose sums add up to zero.
int main()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    std::cin >> n;
    // A non-positive or unreadable length yields no subarrays (prints 0)
    // instead of indexing into an empty vector.
    std::vector<std::int64_t> nums(n > 0 ? static_cast<std::size_t>(n) : 0);
    for (std::int64_t& value : nums) {
        std::cin >> value;
    }
    std::cout << count_zero_sum_triples(nums);
}