// Counts the substrings of a string over {'a','b','c'} in which every letter
// that occurs, occurs equally often. Such a substring contains exactly one,
// two or three distinct letters, so the answer is the sum of three disjoint
// counts computed by the helpers below.
//
// NOTE(review): the helpers assume the input consists only of 'a', 'b' and
// 'c' (count_three() treats any other byte as 'c', count_two() treats it as
// a segment break) -- confirm against the problem statement.

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>

using ll = long long;

// Input buffer. The scanf field width in main() (1333332) must stay one less
// than this size so the terminating NUL always fits.
char s[1333333];
int n;  // strlen(s)

// Substrings containing 'a', 'b' and 'c' the same number of times (>= 1 each).
//
// For the prefix s[0..i) let d = (#b - #a, #c - #a). The substring s[i..j)
// has equal counts iff the prefixes at i and j have the same d, so the answer
// is the sum of C(k, 2) over groups of k equal prefix vectors.
// Each coordinate of d lies in [-n, n]; shifting by n and packing both into a
// single 64-bit key lets us group by sorting. This avoids specializing
// std::hash for std::pair<int, int> (not allowed for a type that is not
// program-defined) and the signed-int overflow of the old hash expression.
ll count_three() {
    const ll width = 2LL * n + 1;
    std::vector<ll> keys;
    keys.reserve(static_cast<std::size_t>(n) + 1);
    ll d1 = 0;  // #b - #a
    ll d2 = 0;  // #c - #a
    keys.push_back((d1 + n) * width + (d2 + n));  // the empty prefix
    for (int i = 0; i < n; i++) {
        if (s[i] == 'a') {
            --d1;
            --d2;
        } else if (s[i] == 'b') {
            ++d1;
        } else {
            ++d2;
        }
        keys.push_back((d1 + n) * width + (d2 + n));
    }
    std::sort(keys.begin(), keys.end());
    ll ret = 0;
    for (std::size_t lo = 0; lo < keys.size();) {
        std::size_t hi = lo;
        while (hi < keys.size() && keys[hi] == keys[lo]) ++hi;
        const ll k = static_cast<ll>(hi - lo);
        ret += k * (k - 1) / 2;  // pairs of equal prefixes
        lo = hi;
    }
    return ret;
}

// Substrings made only of the letters a and b, containing both equally often.
//
// Any other character splits the string into independent segments. Within a
// segment, s[i..j) is balanced iff the running balance (#a - #b) is equal at
// i and j. Prefix counts live in a flat array indexed by balance + n. On a
// segment break, only the entries touched in that segment are reset, so the
// total work is O(n). Clearing an unordered_map per separator can instead
// cost O(bucket_count) each time, because libstdc++ zeroes the whole
// bucket array and never shrinks it.
ll count_two(char a, char b) {
    std::vector<int> freq(2 * static_cast<std::size_t>(n) + 1, 0);
    std::vector<int> touched;  // indices of freq that are currently non-zero
    ll ret = 0;
    // Records one more prefix with balance `bal`, first adding the number of
    // earlier prefixes in the current segment that it pairs with.
    auto add_prefix = [&](int bal) {
        int &c = freq[bal + n];  // freq is never resized, so the ref stays valid
        if (c == 0) touched.push_back(bal + n);
        ret += c;
        ++c;
    };
    int p = 0;
    add_prefix(p);
    for (int i = 0; i < n; i++) {
        if (s[i] == a) {
            ++p;
        } else if (s[i] == b) {
            --p;
        } else {
            // Segment break: forget every prefix seen so far.
            for (int idx : touched) freq[idx] = 0;
            touched.clear();
            p = 0;
        }
        add_prefix(p);
    }
    return ret;
}

// Substrings consisting of a single repeated letter: a run of length L
// contributes L*(L+1)/2, accumulated here one position at a time.
ll count_one() {
    ll ret = 0;
    char prv = 0;  // NUL never occurs inside s, so s[0] always starts a run
    int cnt = 0;   // length of the run ending at position i
    for (int i = 0; i < n; i++) {
        if (s[i] != prv) {
            prv = s[i];
            cnt = 0;
        }
        cnt++;
        ret += cnt;
    }
    return ret;
}

int main() {
    // Bounded read: the field width keeps scanf from writing past s.
    if (std::scanf("%1333332s", s) != 1) {
        s[0] = '\0';  // no token read: treat as the empty string
    }
    n = static_cast<int>(std::strlen(s));
    const ll total = count_three() + count_two('a', 'b') + count_two('a', 'c') +
                     count_two('b', 'c') + count_one();
    std::printf("%lld\n", total);
}