// Counts the substrings of a word over {a, b, c} in which all letters that occur appear equally often.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <map>
#include <queue>
#include <vector>

using namespace std;

typedef long long LL;

// Capacity of the input buffer `word` (so a word holds at most M-1 letters).
// M is also the offset added to a letter-count difference, which may be
// negative, to turn it into a non-negative array index.
const int M=300011;
// Number of distinct letters handled ('a', 'b', 'c').
const int P=3;

typedef pair<int, int> PI;
typedef vector<int> VI;

// Kept as macros on purpose: PB is used as a member name (v.PB(x)), which a
// typedef or constant cannot express.
#define MP make_pair
#define PB push_back

// Length of the input word.
int n;
// The input word, NUL-terminated.
char word[M];
// last[c]: index of the most recent occurrence of letter c (main resets it to
// -1 before scanning); lCount[c]: how many times c has occurred so far.
int last[P], lCount[P];
// Number of prefixes seen so far for each pair
// (count(a)-count(b), count(a)-count(c)).
map<PI, int> diff;
// positions[k][d+M]: end indices (in increasing order) of the prefixes whose
// count difference for letter pair k equals d. Pair 0 is (a,b), pair 1 is
// (a,c), pair 2 is (b,c).
vector<int> positions[3][2*M+1];
// cursor[k][d+M]: index of the first entry of positions[k][d+M] that has not
// been discarded yet.
int cursor[3][2*M+1];

int main() {
	scanf("%s", word);
	n=strlen(word);

	LL result=0;

	last[0]=last[1]=last[2]=-1;
	positions[0][M].PB(-1);
	positions[1][M].PB(-1);
	positions[2][M].PB(-1);
	diff[MP(0, 0)]=1;

	for (int i=0; i<n; i++) {
		if (word[i]<'a' || word[i]>'c') {
			printf("INVALID word: word[%d]=%c\n", i, word[i]);
			return 1;
		}

		int curr=word[i]-'a';

		last[curr]=i;
		lCount[curr]++;

		// Words finishing at this point consisting of exactly one letter.
		int prev=-1;
		for (int a=0; a<P; a++)
			if (a!=curr && (prev==-1 || last[a]>last[prev])) prev=a;
		result+=(LL)(i-last[prev]);

		int other=3-curr-prev;

		// Words finishing at this point consisting of exactly two letters.
		if (last[prev]!=-1) {
			// We know that those letters are curr and prev.
			// We look for positions that are:
			// - greater or equal than last[other]
			// - that have equal difference in count of curr and prev

			int idx=curr+prev-1;
			int a=min(curr, prev), b=max(curr, prev);
			int d=lCount[a]-lCount[b];

			while (cursor[idx][d+M]<positions[idx][d+M].size()) {
				int x=positions[idx][d+M][cursor[idx][d+M]];
				if (x<last[other]) cursor[idx][d+M]++;
				else break;
			}
			result+=(LL)(positions[idx][d+M].size()-cursor[idx][d+M]);
		}
		positions[0][lCount[0]-lCount[1]+M].PB(i);
		positions[1][lCount[0]-lCount[2]+M].PB(i);
		positions[2][lCount[1]-lCount[2]+M].PB(i);

		// Words finishing at this point consisting of exactly three letters.
		PI currDiff=MP(lCount[0]-lCount[1], lCount[0]-lCount[2]);
		result+=(LL)diff[currDiff];
		diff[currDiff]++;
	}

	printf("%lld\n", result);
	return 0;
}