#include <bits/stdc++.h>
using namespace std;
// 64-bit accumulator type used for all substring counts.
typedef long long ll;
// Input string buffer (NUL-terminated). Its capacity bounds the accepted input
// length to sizeof(s) - 1 characters — presumably chosen above the problem's
// length limit; confirm against the task constraints.
char s[1333333] = {};
// Length of the string currently held in `s`.
int n = 0;
// Hash for pair<int, int> keys used by unordered containers in this file.
//
// The previous body computed `p.first * 12345` in `int`, which overflows
// (undefined behavior) once |p.first| exceeds ~174k. Here both coordinates are
// reinterpreted as 32-bit unsigned values and packed into one 64-bit word, then
// passed through the splitmix64 finalizer. All arithmetic is unsigned, so
// wrap-around is well defined. Packing is injective and the finalizer is a
// bijection on 64-bit values, so distinct pairs never collide when size_t is
// 64 bits wide.
//
// NOTE(review): [namespace.std] only permits specializing std templates for
// declarations that depend on a program-defined type; pair<int, int> does not
// qualify, so this is formally undefined even though mainstream compilers
// accept it. A named hasher passed as the container's Hash argument avoids it.
namespace std {
template<>
struct hash<pair<int, int>> {
    size_t operator()(pair<int, int> p) const {
        std::uint64_t x =
            (static_cast<std::uint64_t>(static_cast<std::uint32_t>(p.first)) << 32) |
            static_cast<std::uint64_t>(static_cast<std::uint32_t>(p.second));
        // splitmix64 finalizer: spreads entropy across all bits.
        x ^= x >> 30;
        x *= 0xbf58476d1ce4e5b9ULL;
        x ^= x >> 27;
        x *= 0x94d049bb133111ebULL;
        x ^= x >> 31;
        return static_cast<size_t>(x);
    }
};
}  // namespace std
// Counts non-empty substrings of s[0..n) in which 'a', 'b' and the remaining
// character ('c', or anything that is neither 'a' nor 'b') occur equally often.
//
// A substring s[i..j) qualifies iff prefixes i and j have the same signature
// (#b - #a, #other - #a). The answer is therefore the number of unordered pairs
// of equal signatures among the n + 1 prefixes.
//
// Each signature is packed into one 64-bit key; sorting the keys and summing
// the lengths of runs of equal values counts those pairs deterministically in
// O(n log n). This avoids relying on a std::hash<pair<int, int>> specialization
// (formally UB, and previously an overflowing, collision-prone hash).
//
// Reads the globals `s` and `n`. Returns the count as ll.
ll count_three() {
    // Both differences lie in [-n, n], i.e. exactly `span` consecutive values,
    // so key = db * span + dc is injective (balanced base-`span` encoding).
    const ll span = 2LL * n + 1;

    std::vector<ll> keys;
    keys.reserve(static_cast<std::size_t>(n) + 1);

    ll db = 0;  // #b - #a in the current prefix
    ll dc = 0;  // #other - #a in the current prefix
    keys.push_back(0);  // signature of the empty prefix
    for (int i = 0; i < n; i++) {
        if (s[i] == 'a') {
            db--;
            dc--;
        } else if (s[i] == 'b') {
            db++;
        } else {
            dc++;
        }
        keys.push_back(db * span + dc);
    }

    std::sort(keys.begin(), keys.end());

    // For each key, `run` is how many earlier (sorted) keys equal it; summing
    // it over a group of size k yields k * (k - 1) / 2 pairs.
    ll ret = 0;
    ll run = 0;
    for (std::size_t i = 1; i < keys.size(); i++) {
        run = (keys[i] == keys[i - 1]) ? run + 1 : 0;
        ret += run;
    }
    return ret;
}
// Counts non-empty substrings of s[0..n) made up only of the characters `a`
// and `b`, with equally many of each (so both occur at least once).
//
// Any other character acts as a separator: it starts a new segment whose
// prefix balance (#a - #b) restarts at 0. Within a segment, a substring is
// balanced iff its two bounding prefixes have the same balance.
//
// The original kept the balances in an unordered_map and called clear() at
// every separator. In libstdc++ clear() is O(bucket_count), and the bucket
// count never shrinks, so inputs like "aaaa...acccc...c" degrade to Θ(n²).
// Here the balances live in a dense array. Because the balance moves by ±1,
// the touched entries of a segment form a contiguous range [lo, hi], and only
// that range is reset, which is amortized O(n) overall.
//
// Reads the globals `s` and `n`. Returns the count as ll.
ll count_two(char a, char b) {
    ll ret = 0;
    const int off = n;                    // balance p is stored at index p + off
    std::vector<int> seen(2 * static_cast<std::size_t>(n) + 1, 0);  // p ∈ [-n, n]
    int p = 0;
    int lo = 0;                           // smallest balance touched in this segment
    int hi = 0;                           // largest balance touched in this segment
    seen[off] = 1;                        // empty prefix
    for (int i = 0; i < n; i++) {
        if (s[i] == a) {
            p++;
        } else if (s[i] == b) {
            p--;
        } else {
            // Separator: forget this segment's prefixes; the prefix ending at
            // the separator (balance 0) is recorded just below.
            std::fill(seen.begin() + (off + lo), seen.begin() + (off + hi + 1), 0);
            p = lo = hi = 0;
        }
        lo = std::min(lo, p);
        hi = std::max(hi, p);
        ret += seen[off + p];
        seen[off + p]++;
    }
    return ret;
}
// Counts non-empty substrings of s[0..n) consisting of a single repeated
// character.
//
// The string is scanned one maximal run of identical characters at a time; a
// run of length k contributes k * (k + 1) / 2 such substrings.
//
// Reads the globals `s` and `n`. Returns the count as ll.
ll count_one() {
    ll total = 0;
    int start = 0;
    while (start < n) {
        int end = start;
        while (end < n && s[end] == s[start]) {
            ++end;
        }
        const ll len = end - start;
        total += len * (len + 1) / 2;
        start = end;
    }
    return total;
}
// Reads one whitespace-delimited token into `s` and prints the number of
// non-empty substrings in which every character present occurs equally often.
// The count is split by how many distinct characters the substring contains,
// so the five parts are disjoint and simply added.
int main() {
    // A bare "%s" has no length limit and can overflow `s`; bound the read by
    // the buffer's capacity (leaving room for the terminating NUL).
    char fmt[32];
    std::snprintf(fmt, sizeof fmt, "%%%zus", sizeof s - 1);
    if (std::scanf(fmt, s) != 1) {
        s[0] = '\0';  // no token read: treat as the empty string (answer 0)
    }
    n = static_cast<int>(std::strlen(s));
    std::printf("%lld\n", count_three() + count_two('a', 'b') + count_two('a', 'c') + count_two('b', 'c') + count_one());
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 | #include <bits/stdc++.h> using namespace std; using ll = long long; char s[1333333]; int n; template<> struct std::hash<pair<int, int>> { size_t operator()(pair<int, int> p) const { return p.first * 12345 ^ p.second; } }; ll count_three() { ll ret = 0; unordered_map<pair<int, int>, int> m; pair<int, int> p{0, 0}; m[p]++; for (int i = 0; i < n; i++) { if (s[i] == 'a') { p.first--; p.second--; } else if (s[i] == 'b') { p.first++; } else { p.second++; } ret += m[p]; m[p]++; } return ret; } ll count_two(char a, char b) { ll ret = 0; unordered_map<int, int> m; m[0]++; int p = 0; for (int i = 0; i < n; i++) { if (s[i] == a) { p++; } else if (s[i] == b) { p--; } else { m.clear(); p = 0; } ret += m[p]; m[p]++; } return ret; } ll count_one() { ll ret = 0; char prv = 0; int cnt = 0; for (int i = 0; i < n; i++) { if (s[i] != prv) { prv = s[i]; cnt = 0; } cnt++; ret += cnt; } return ret; } int main() { scanf("%s", s); n = strlen(s); printf("%lld\n", count_three() + count_two('a', 'b') + count_two('a', 'c') + count_two('b', 'c') + count_one()); } |
English