#include <cstdio>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <map>
#include <utility>

using ll = long long;
using pii = std::pair<int, int>;

// Capacity of the input buffer, including the terminating NUL.
constexpr int N = 300007;

char s[N];  // Input string, NUL-terminated; zero-initialised as a global.
int n;      // Length of `s`, set once in main().

/// Counts the non-empty substrings of `s` made up solely of character `c`.
///
/// While scanning, `run` is the length of the current streak of `c` ending at
/// position i; exactly `run` qualifying substrings end at that position.
ll count1(char c) {
    ll res = 0;
    int run = 0;
    for (int i = 0; i < n; ++i) {
        if (s[i] == c) {
            res += ++run;
        } else {
            run = 0;
        }
    }
    return res;
}

/// Counts the substrings of `s` that contain only `c1` and `c2`, with equally
/// many of each.
///
/// `balance` is (#c1 - #c2) since the last character that is neither `c1` nor
/// `c2`. Two positions with the same balance delimit a qualifying substring,
/// so each step adds the number of earlier positions sharing the balance.
ll count2(char c1, char c2) {
    ll res = 0;
    std::map<int, int> seen{{0, 1}};  // balance -> occurrences in this segment
    int balance = 0;
    for (int i = 0; i < n; ++i) {
        if (s[i] == c1) {
            ++balance;
        } else if (s[i] == c2) {
            --balance;
        } else {
            // A foreign character cannot be part of any counted substring:
            // start a fresh segment with only the empty prefix recorded.
            seen.clear();
            seen[0] = 1;
            balance = 0;
            continue;
        }
        res += seen[balance]++;
    }
    return res;
}

/// Counts the substrings of `s` with equally many 'a', 'b' and 'c'.
///
/// The running key is (#b - #a, #c - #b); two prefixes with the same key
/// delimit a substring in which all three counts are equal.
/// NOTE(review): every character other than 'a' and 'b' is treated as 'c';
/// this presumably relies on the input alphabet being {a, b, c} - confirm.
ll count3() {
    ll res = 0;
    std::map<pii, int> seen{{{0, 0}, 1}};  // key -> number of prefixes so far
    pii key{};
    for (int i = 0; i < n; ++i) {
        if (s[i] == 'a') {
            --key.first;
        } else if (s[i] == 'b') {
            ++key.first;
            --key.second;
        } else {
            ++key.second;
        }
        res += seen[key]++;
    }
    return res;
}

/// Reads one whitespace-delimited token from stdin and prints the sum of the
/// single-letter, two-letter and three-letter balanced-substring counts.
int main() {
    // setw(N) limits the extraction to N - 1 characters plus the NUL, so an
    // over-long token can no longer overflow the fixed-size buffer.
    std::cin >> std::setw(N) >> s;
    n = static_cast<int>(std::strlen(s));

    ll res = count1('a') + count1('b') + count1('c');
    res += count2('a', 'b') + count2('a', 'c') + count2('b', 'c');
    res += count3();

    std::printf("%lld\n", res);
    return 0;
}