// Reads one line from standard input and prints the number of non-empty
// substrings in which every letter that occurs appears the same number of
// times. The counting logic assumes the line consists of the letters
// 'a', 'b' and 'c'.
#include<cstdio>
#include<cstdint>
#include<vector>
#include<map>
#include<set>
#include<utility>

using namespace std;

/**
 * Reads one line from stdin and prints the number of non-empty substrings in
 * which every letter that occurs appears equally often.
 *
 * The total is accumulated in three independent phases:
 *   1. substrings made of a single repeated character (runs),
 *   2. substrings made of exactly two of the letters {a, b, c} in equal
 *      numbers, one pass per letter pair,
 *   3. substrings containing 'a', 'b' and 'c' in equal numbers.
 *
 * Phases 2 and 3 use the prefix-difference trick: two prefixes with the same
 * count differences delimit a substring whose letter counts are balanced.
 *
 * Note: in phase 3 any character other than 'a' or 'b' is counted as 'c',
 * so the result is only meaningful for input over the alphabet {a, b, c}.
 *
 * @return 0 always.
 */
int main() {
	std::uint64_t sum = 0;
	std::vector<char> word;
	word.reserve(300008);

	// Phase 1: read the line while counting single-character runs.
	// A run of length k contributes k * (k + 1) / 2 substrings.
	std::uint64_t run_length = 0;
	char last = 0;
	while (true) {
		char c;
		// Stop at end of line or at end of input. Without the return-value
		// check, input lacking a trailing newline would loop forever on an
		// uninitialized character. '\r' is treated as a line end so CRLF
		// input is not miscounted.
		if (std::scanf("%c", &c) != 1 || c == '\n' || c == '\r') {
			break;
		}
		if (c == last) {
			++run_length;
		}
		else {
			sum += run_length * (run_length + 1) / 2;
			run_length = 1;
			last = c;
		}
		word.push_back(c);
	}
	// Account for the final (still open) run.
	sum += run_length * (run_length + 1) / 2;

	// Phase 2: for each pair of letters, count substrings that contain only
	// those two letters, in equal numbers.
	static const char kPairs[3][2] = { {'a', 'b'}, {'b', 'c'}, {'c', 'a'} };
	for (int p = 0; p < 3; ++p) {
		const char first = kPairs[p][0];
		const char second = kPairs[p][1];
		int delta = 0;               // (#first - #second) in the current segment
		std::map<int, int> seen;     // delta value -> number of prefixes with it
		seen[0] = 1;                 // the empty prefix
		for (char ch : word) {
			if (ch != first && ch != second) {
				// A foreign letter splits the word: start a fresh segment.
				delta = 0;
				seen.clear();
				seen[0] = 1;
				continue;
			}
			delta += (ch == first) ? 1 : -1;
			// Every earlier prefix with the same delta ends a balanced
			// substring here. operator[] value-initializes a missing entry
			// to 0, so a single lookup suffices.
			int& occurrences = seen[delta];
			sum += static_cast<std::uint64_t>(occurrences);
			++occurrences;
		}
	}

	// Phase 3: substrings with equal numbers of 'a', 'b' and 'c'.
	// Track (#a - #b) and (#b - #c); the third difference (#c - #a) is always
	// the negated sum of the other two, so it adds nothing to the key.
	std::map<std::pair<int, int>, int> seen;
	seen[std::make_pair(0, 0)] = 1;  // the empty prefix
	int amb = 0;
	int bmc = 0;
	for (char ch : word) {
		if (ch == 'a') {
			++amb;
		}
		else if (ch == 'b') {
			++bmc;
			--amb;
		}
		else {
			--bmc;
		}
		int& occurrences = seen[std::make_pair(amb, bmc)];
		sum += static_cast<std::uint64_t>(occurrences);
		++occurrences;
	}

	// "%lu" is not a portable format for uint64_t; unsigned long long with
	// "%llu" is at least 64 bits wide on every conforming implementation.
	std::printf("%llu\n", static_cast<unsigned long long>(sum));
	return 0;
}