1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
#define MAKS 300010
using namespace std;
typedef long long int lld;
char s[MAKS];

vector<lld> f;
vector<int> g;

int main()
{
    scanf("%s",s);
    int n = strlen(s);
    lld wyn=0;
    // jedynki
    int seria=1;
    for(int i=1;i<n;i++)
    {
        if(s[i]==s[i-1])seria++;
        else
        {
            //printf("seria %d\n",seria);
            wyn += lld(seria)*lld(seria-1)/2LL + lld(seria);
            seria=1;
        }
    }
    //printf("seria %d\n",seria);
    wyn += lld(seria)*lld(seria-1)/2LL + lld(seria);

    // trójki
    int cnt[3]={0,0,0};
    f.push_back(0);
    for(int i=0;i<n;i++)
    {
        cnt[s[i]-'a']++;
        int m = min(cnt[0], min(cnt[1], cnt[2]));
        cnt[0]-=m; cnt[1]-=m; cnt[2]-=m;
        //printf("%d %d %d\n", cnt[0], cnt[1], cnt[2]);
        lld h = lld(cnt[0]) + lld(MAKS)*lld(cnt[1]) + lld(MAKS)*lld(MAKS)*lld(cnt[2]);
        f.push_back(h);
    }
    sort(f.begin(), f.end());
    seria=0;
    for(int i=0;i<f.size();i++)
    {
        if(i>0 && f[i]!=f[i-1])
        {
            wyn+=lld(seria)*lld(seria-1)/2LL;
            seria=1;
        }
        else seria++;
    }
    wyn+=lld(seria)*lld(seria-1)/2LL;

    f.clear();


    // dwójki
    for(char q='a';q<='c';q++)
    {
        char x;
        if(q=='a')x='b';
        else x='a';

        int i=0;
        while(i<n)
        {
            while(i<n && s[i]==q)i++;
            g.clear();
            g.push_back(0);
            int balans=0;
            while(i<n && s[i]!=q)
            {
                if(s[i]==x)balans++;
                else balans--;
                g.push_back(balans);
                i++;
            }
            sort(g.begin(), g.end());
            seria=0;
            for(int j=0;j<g.size();j++)
            {
                if(j>0 && g[j]!=g[j-1])
                {
                    wyn+=lld(seria)*lld(seria-1)/2LL;
                    seria=1;
                }
                else seria++;
            }
            wyn+=lld(seria)*lld(seria-1)/2LL;
        }
    }

    printf("%lld\n", wyn);
}