1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <bits/stdc++.h>
using namespace std;

const int N = 3e5 + 5;

int n;
char S[N];
map<pair<int, int>, int> kub;
map<pair<int, int>, int> kub_a, kub_b, kub_c;

int main() {
	scanf("%s", &S);
	int n = strlen(S);
	kub[make_pair(0, 0)] = 1;
	kub_a[make_pair(0, 0)] = 1;
	kub_b[make_pair(0, 0)] = 1;
	kub_c[make_pair(0, 0)] = 1;
	int d1 = 0, d2 = 0, d3 = 0;
	long long res = 0;
	for (int i = 0; i < n; i++) {
		if (S[i] == 'a') {
			d1++;
		}
		if (S[i] == 'b') {
			d2++;
		}
		if (S[i] == 'c') {
			d3++;
		}
		res += kub[make_pair(d1 - d2, d1 - d3)];
		kub[make_pair(d1 - d2, d1 - d3)]++;
		res += kub_a[make_pair(d1, d2 - d3)];
		kub_a[make_pair(d1, d2 - d3)]++;
		res += kub_b[make_pair(d2, d1 - d3)];
		kub_b[make_pair(d2, d1 - d3)]++;
		res += kub_c[make_pair(d3, d1 - d2)];
		kub_c[make_pair(d3, d1 - d2)]++;
	}
	res++;
	int last = 0;
	for (int i = 1; i < n; i++) {
		if (S[i] != S[i - 1]) {
			last = i;
		}
		res += (i - last + 1);
	}
	printf("%lld\n", res);
}