// Reads an integer array from stdin and prints the number of unordered triples
// of distinct contiguous subarrays whose three sums add up to zero.
#include <iostream>
#include <vector>
#include <string>
#include <numeric>
#include <algorithm>
#include <map>
#include <set>
#include <queue>
#include <stack>
#include <tuple>

// Builds a histogram mapping each contiguous-subarray sum of `a` to the number
// of subarrays (i..j) that produce it. Sums are accumulated in long long so
// that adding many int elements cannot overflow.
static std::map<long long, long long> subarray_sum_histogram(const std::vector<int>& a)
{
    std::map<long long, long long> hist;

    for (auto first = a.begin(); first != a.end(); ++first)
    {
        long long running = 0;
        for (auto last = first; last != a.end(); ++last)
        {
            running += *last;
            ++hist[running];
        }
    }

    return hist;
}

// Counts unordered triples of distinct subarrays whose sums add up to zero,
// given the histogram `hist` (sum -> number of subarrays with that sum).
// The triples are split into three disjoint cases by how many of the three
// sums coincide.
static long long count_zero_sum_triples(const std::map<long long, long long>& hist)
{
    long long result = 0;

    // Case 1: all three sums are equal, which forces every sum to be 0.
    // C(c, 3) is evaluated as C(c, 2) * (c - 2) / 3; both divisions are exact
    // and the intermediate stays smaller than c * (c - 1) * (c - 2).
    const auto zero = hist.find(0);
    if (zero != hist.end())
    {
        const long long c = zero->second;
        result += c * (c - 1) / 2 * (c - 2) / 3;
    }

    // Case 2: exactly two sums are equal (value k != 0), the third is -2k.
    // Counts are kept in long long: v * (v - 1) overflows int very quickly.
    for (const auto& entry : hist)
    {
        const long long k = entry.first;
        const long long v = entry.second;

        if (k == 0 || v < 2)
        {
            continue;
        }

        const auto partner = hist.find(-2 * k);
        if (partner != hist.end())
        {
            result += v * (v - 1) / 2 * partner->second;
        }
    }

    // Case 3: three pairwise different sums lo < mid < hi with lo + mid + hi == 0.
    // For each lo, mid walks up and hi walks down over the ordered keys.
    for (auto lo = hist.cbegin(); lo != hist.cend(); ++lo)
    {
        auto mid = std::next(lo);
        auto hi = std::prev(hist.cend());

        while (mid != hist.cend() && mid->first < hi->first)
        {
            const long long total = lo->first + mid->first + hi->first;

            if (total == 0)
            {
                result += lo->second * mid->second * hi->second;
                ++mid;
                --hi;
            }
            else if (total < 0)
            {
                // Too small: jump mid straight to the first key that could
                // balance the current hi. Reaching cend() ends the loop via
                // the loop condition.
                mid = hist.lower_bound(mid->first - total);
            }
            else
            {
                // Too large: jump hi to the largest key that could balance
                // the current mid, if any such key exists.
                const auto bound = hist.upper_bound(hi->first - total);
                if (bound == hist.cbegin())
                {
                    break;
                }
                hi = std::prev(bound);
            }
        }
    }

    return result;
}

// Entry point: reads n followed by n integers from stdin and prints the number
// of zero-sum triples of subarrays. A missing, unreadable or negative n is
// treated as an empty array, for which the answer is 0.
int main()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    if (!(std::cin >> n) || n < 0)
    {
        n = 0;
    }

    // std::vector instead of a variable-length array: standard C++, heap
    // allocated, and zero-initialised if the input runs short.
    std::vector<int> a(static_cast<std::vector<int>::size_type>(n));
    for (int& value : a)
    {
        std::cin >> value;
    }

    const std::map<long long, long long> hist = subarray_sum_histogram(a);

    // '\n' rather than std::endl: stdout is flushed at exit anyway.
    std::cout << count_zero_sum_triples(hist) << '\n';

    return 0;
}