1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

// Input word, stored 1-based in s[1..n] and NUL-terminated at s[n + 1].
char s[300005];
// Prefix vectors built by main(): w[i] = (#a - #c, #b - #c) over s[1..i].
std::pair<int, int> w[300005];
// Length of the input word.
int n;
// Counter table used by dodaj() through a base offset, so that negative
// indices are addressable; dodaj() zeroes every entry it touches before
// returning, which keeps the table reusable across calls.
int licz[600010];
// Scans s[1..n] and splits it into maximal runs that contain no letter `l`.
// Inside each run it tracks the balance (#a - #b, counted from the start of
// the run) and adds to *wynik the number of pairs of prefix positions with
// equal balance. When the run consists only of `a` and `b`, each such pair is
// one non-empty substring with equally many `a`s and `b`s.
//
// The counters live in the global table licz, addressed from the base offset
// 300005 so negative balances are valid indices. Only the balance range
// actually visited is cleared, at every `l` and once more at the end, so every
// counter this call touched is zero again when it returns.
void dodaj(long long *wynik, char l, char a, char b) {
    int *cnt = licz + 300005;  // cnt[k]: prefixes of the current run with balance k
    int bal = 0;               // current balance #a - #b inside the run
    int lo = 0;                // smallest balance visited in the run
    int hi = 0;                // largest balance visited in the run

    // Zero every counter touched since the last reset.
    auto wipe = [&]() {
        for (int k = lo; k <= hi; ++k) cnt[k] = 0;
    };

    cnt[0] = 1;  // the empty prefix at the start of the first run
    for (int pos = 1; pos <= n; ++pos) {
        const char ch = s[pos];
        if (ch == a) {
            ++bal;
            // Pair the new prefix with every earlier one of equal balance.
            *wynik += cnt[bal]++;
            hi = bal > hi ? bal : hi;
        }
        if (ch == b) {
            --bal;
            *wynik += cnt[bal]++;
            lo = bal < lo ? bal : lo;
        }
        if (ch == l) {
            // `l` terminates the current run: forget its counters and restart.
            wipe();
            bal = lo = hi = 0;
            cnt[0] = 1;
        }
    }
    wipe();
}
// Reads one word from stdin and prints a single count.
//
// Assumption (not checked here): the word consists only of 'a', 'b' and 'c'.
// Under that assumption the printed total is the number of non-empty
// substrings in which every letter that occurs appears equally often. It is
// the sum of three disjoint parts:
//   1. substrings made of one repeated letter (counted from run lengths);
//   2. substrings with #a == #b == #c (pairs of equal prefix vectors w);
//   3. substrings missing exactly one letter whose other two letters are
//      balanced (one dodaj() call per missing letter).
int main() {
    // Bound the read to the buffer: s[1..300003] plus the NUL at s[300004].
    // The previous unbounded "%s" could overflow s on oversized input, and
    // this bound also keeps dodaj()'s balance indices inside licz.
    if (scanf("%300003s", s + 1) != 1) s[1] = '\0';
    n = strlen(s + 1);

    long long wynik = 0;

    // Build the prefix vectors w[i] = (#a - #c, #b - #c) over s[1..i].
    // In the same pass do part 1: the single-letter substrings ending at i
    // are exactly the suffixes of the current run of equal letters.
    w[0] = make_pair(0, 0);
    int run = 0;
    for (int i = 1; i <= n; i++) {
        w[i] = w[i - 1];
        switch (s[i]) {
            case 'a': w[i].first++; break;
            case 'b': w[i].second++; break;
            case 'c': w[i].first--; w[i].second--; break;
        }
        run = (i > 1 && s[i] == s[i - 1]) ? run + 1 : 1;
        wynik += run;
    }

    // Part 2: s[i+1..j] has equal counts of a, b and c exactly when
    // w[i] == w[j]. After sorting, equal vectors form contiguous groups, and
    // a group of size g contributes g * (g - 1) / 2 pairs.
    sort(w, w + n + 1);
    int ost = 0;  // first index of the current group
    for (int i = 1; i <= n; i++) {
        if (w[i] != w[i - 1]) {
            long long group = i - ost;
            wynik += group * (group - 1) / 2;
            ost = i;
        }
    }
    long long last = n + 1 - ost;  // size of the final group w[ost..n]
    wynik += last * (last - 1) / 2;

    // Part 3: one call per excluded letter (the first argument).
    dodaj(&wynik, 'a', 'b', 'c');
    dodaj(&wynik, 'b', 'c', 'a');
    dodaj(&wynik, 'c', 'a', 'b');
    printf("%lld\n", wynik);
    return 0;
}