1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <cstdio>
#include <string.h>
#include <map>

using namespace std;

// 64-bit result type: the number of substrings of an n-character string is
// n*(n+1)/2, which overflows 32-bit int for n near 3e5.
using LL = long long;


// Buffer capacity: room for 300000 characters plus the terminating NUL.
constexpr int MAXN = 300001;

// Input string read by main(); global, so it is zero-initialised and kept off the stack.
char s[MAXN];

int main() {
    scanf("%s", &s);
    int n = strlen(s);
    LL res = 0;

    int c[3];
    c[0] = c[1] = c[2] = 0;

    map<int, int> m2[3];
    m2[0][0] = m2[1][0] = m2[2][0] =  1;
    map<pair<int, int>, int> m3;
    m3[pair<int, int>(0, 0)] = 1;


    LL m1 = 1;
    for (int i = 0; i < n; i++) {

        if(s[i] != s[i+1]) {
            res += (m1*(m1+1))/2;
            m1 = 0;
        }
        m1++;;

        c[s[i] - 'a']++;
        int ab = c[0] - c[1];
        int bc = c[1] - c[2];
        int ca = c[2] - c[0];

        m2[s[i]-'a'].clear();

        for(int j = 0; j < 3; j++) {
            int p = (j == 0) ? bc : ((j == 1) ? ca : ab);
            if(m2[j].find(p) == m2[j].end()) {
                m2[j][p] = 0;
            }
            res += m2[j][p];
            m2[j][p]++;
        }
        auto p = pair<int, int> (ab, bc);
        if(m3.find(p) == m3.end()) {
            m3[p] = 0;
        }
        res += m3[p];
        m3[p]++;
    }
    printf("%lld\n", res);
}