// Counts the non-empty substrings of a string over {a, b, c} in which every
// letter that occurs at all occurs equally often, and prints the count.
//
// Approach: for a non-empty letter set M, substring s[i, j) is "balanced for
// M" iff (a) no letter outside M occurs in it and (b) all letters of M occur
// equally often. Both conditions compare prefix counts only:
//   * a letter c outside M is absent  <=>  cnt[j][c] == cnt[i][c];
//   * letters of M are equally frequent <=> cnt[.][c] - cnt[.][f] agree at
//     i and j for every c in M, where f is any fixed letter of M.
// So, for each M, every prefix is mapped to a key and pairs of prefixes with
// equal keys are counted. A non-empty substring balanced for M has every
// letter of M occurring k > 0 times (k == 0 would force length 0), so M is
// exactly its set of present letters: each qualifying substring is counted
// for exactly one of the 7 masks, and summing over the masks gives the answer.
#include <algorithm>
#include <array>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

namespace {

using ll = long long;

constexpr int kLetters = 3;                 // alphabet is 'a', 'b', 'c'
using Counts = std::array<int, kLetters>;   // per-letter prefix counts / keys

// Number of unordered pairs that can be formed from `m` equal items.
ll pairs(ll m) { return m * (m - 1) / 2; }

// Counts pairs i < j of prefixes (rows of `cnt`) whose substring s[i, j) is
// balanced for the letter set `mask` (bit j set <=> letter 'a' + j in the
// set). `mask` must be non-zero. `keys` is scratch storage reused across
// calls to avoid reallocating once per mask.
ll count_for_mask(const std::vector<Counts>& cnt, int mask,
                  std::vector<Counts>& keys) {
    // Reference letter f: the lowest letter in the set (mask != 0, so the
    // loop terminates).
    int f = 0;
    while (((mask >> f) & 1) == 0) ++f;

    keys.clear();
    for (const Counts& c : cnt) {
        Counts key{};
        for (int j = 0; j < kLetters; ++j) {
            // In-set letters: offset from the reference letter (equal offsets
            // at both ends <=> equal frequencies inside the substring).
            // Out-of-set letters: raw count (equal <=> letter absent).
            key[j] = ((mask >> j) & 1) != 0 ? c[j] - c[f] : c[j];
        }
        keys.push_back(key);
    }

    // Group identical keys together and count pairs within each group.
    std::sort(keys.begin(), keys.end());
    ll total = 0;
    for (std::size_t i = 0; i < keys.size();) {
        std::size_t j = i;
        while (j < keys.size() && keys[j] == keys[i]) ++j;
        total += pairs(static_cast<ll>(j - i));
        i = j;
    }
    return total;
}

// Returns the number of non-empty substrings of `s` in which all letters that
// occur do so equally often. Precondition: every character of `s` lies in
// 'a' .. 'a' + kLetters - 1 (checked by the caller).
ll count_balanced_substrings(const std::string& s) {
    const std::size_t n = s.size();

    // cnt[i][c] = occurrences of letter c in the prefix s[0, i).
    std::vector<Counts> cnt(n + 1, Counts{});
    for (std::size_t i = 1; i <= n; ++i) {
        cnt[i] = cnt[i - 1];
        ++cnt[i][s[i - 1] - 'a'];
    }

    std::vector<Counts> keys;
    keys.reserve(cnt.size());

    ll answer = 0;
    for (int mask = 1; mask < (1 << kLetters); ++mask) {
        answer += count_for_mask(cnt, mask, keys);
    }
    return answer;
}

}  // namespace

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::string s;
    std::cin >> s;

    // Any other character would index past the per-letter count array.
    const bool valid = std::all_of(s.begin(), s.end(), [](char ch) {
        return ch >= 'a' && ch < 'a' + kLetters;
    });
    if (!valid) {
        std::cerr << "input must consist of the letters a, b, c only\n";
        return 1;
    }

    std::cout << count_balanced_substrings(s);
}