1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include <bits/stdc++.h>
// Inclusive ascending loop: declares int i and runs it over a, a+1, ..., b.
// The arguments are pasted in unparenthesized, and b is re-evaluated on every iteration.
#define rep(i, a, b) for (int i = a; i <= b; i++)
// Inclusive descending loop: declares int i and runs it over b, b-1, ..., a.
// The arguments are pasted in unparenthesized, and a is re-evaluated on every iteration.
#define per(i, a, b) for (int i = b; a <= i; i--)
// Debug print: writes "<expression text> = <value>" and a newline to cerr.
// The expansion already ends with ';', so a use site needs no extra semicolon.
#define cat(x) cerr << #x << " = " << x << '\n';
// Shorthand for the widest standard signed integer type (at least 64 bits).
using ll = long long;
// Pulls every std name into the global namespace for the rest of the file.
using namespace std;

// Working state shared with main(); everything here is zero-initialized as a global.
//   n       - length of the string stored in s (set from strlen(s + 1)).
//   cnt[k]  - how many times the letter 'a' + k occurs in the prefix scanned so far.
//   last[k] - 1-based position of the most recent 'a' + k, or 0 if it has not appeared yet.
int n, cnt[3], last[3];
// Input string, stored 1-based: characters live in s[1..n] and s[0] is never used by main().
char s[300111];
// Running answer; main() prints it once the whole string has been scanned.
ll res;
// mp[k] maps the value (i - 2 * cnt[k]) to the prefix positions i that produced it.
// Positions are appended in increasing order, so each vector is sorted.
map<int, vector<int>> mp[3];
// Maps the pair (cnt[0] - cnt[1], cnt[1] - cnt[2]) to the prefix positions that produced it,
// again appended in increasing order so each vector is sorted.
map<pair<int, int>, vector<int>> abc;

/**
 * Reads one whitespace-delimited word over the letters 'a', 'b', 'c' from stdin
 * and prints a single number: the count of non-empty substrings in which every
 * letter that occurs in the substring occurs the same number of times.
 *
 * For each end position i (1-based) the substrings ending at i are split by
 * how many distinct letters they contain. With id = letter at i, and x <= y
 * the last positions of the other two letters (0 when a letter has not been
 * seen yet), a substring (p, i] contains
 *   - only the letter id                when y <= p < i  -> always counted;
 *   - exactly two letters               when x <= p < y  -> counted when the
 *     value pos - 2 * cnt[id] is the same at p and at i, i.e. letter id fills
 *     exactly half of the substring;
 *   - all three letters                 when p < x       -> counted when both
 *     count differences (a - b, b - c) are the same at p and at i.
 * Matching start positions are found by binary search in the per-key position
 * lists, which are sorted because positions are appended in increasing order.
 *
 * Returns 0 on success, or 1 (with a message on stderr) if the word contains a
 * character other than 'a', 'b' or 'c'.
 */
int main() {
	std::cin.tie(nullptr)->sync_with_stdio(false);

	// Cap the extraction at the buffer's capacity: s + 1 has sizeof(s) - 1
	// bytes available, and a field width of w stores at most w - 1 characters
	// plus the terminating NUL. Over-long input is truncated instead of
	// overflowing s.
	std::cin >> std::setw(static_cast<int>(sizeof(s) - 1)) >> (s + 1);
	n = static_cast<int>(std::strlen(s + 1));

	// Position 0 stands for the empty prefix, where every balance is zero.
	abc[std::make_pair(0, 0)].push_back(0);
	for (int j = 0; j < 3; j++) {
		mp[j][0].push_back(0);
	}

	for (int i = 1; i <= n; i++) {
		const int id = s[i] - 'a';
		// cnt, last and mp only have slots for three letters; anything else
		// would index out of bounds.
		if (id < 0 || id > 2) {
			std::cerr << "unsupported character at position " << i
			          << "; expected 'a', 'b' or 'c'\n";
			return 1;
		}
		cnt[id]++;
		last[id] = i;

		const int otherA = last[(id + 1) % 3];
		const int otherB = last[(id + 2) % 3];
		const int x = std::min(otherA, otherB);
		const int y = std::max(otherA, otherB);

		// Starts p in [y, i): the substring holds only the letter id.
		res += i - y;

		// Starts p in [x, y): exactly two distinct letters; keep those with the
		// same (position - 2 * cnt[id]) balance as the current prefix.
		// Looked up once; the current position has not been appended yet.
		const std::vector<int>& pairBucket = mp[id][i - 2 * cnt[id]];
		res += std::lower_bound(pairBucket.begin(), pairBucket.end(), y)
		     - std::lower_bound(pairBucket.begin(), pairBucket.end(), x);

		// Starts p < x: all three letters present; keep those with the same
		// pair of count differences as the current prefix.
		const std::pair<int, int> key(cnt[0] - cnt[1], cnt[1] - cnt[2]);
		std::vector<int>& tripleBucket = abc[key];
		res += std::lower_bound(tripleBucket.begin(), tripleBucket.end(), x) - tripleBucket.begin();

		// Register the current prefix only after it has been queried, so it can
		// serve as a start position for later ends but never for itself.
		tripleBucket.push_back(i);
		for (int j = 0; j < 3; j++) {
			mp[j][i - 2 * cnt[j]].push_back(i);
		}
	}

	std::cout << res << "\n";
	return 0;
}