#include<bits/stdc++.h>
using namespace std;
#define all(X) (X).begin(), (X).end()
typedef long long ll;
typedef pair<int,int> pii;

// Counts the non-empty substrings of `s` in which every letter that occurs
// does so the same number of times (e.g. "a", "bb", "ab", "cabbac").
//
// Precondition: `s` consists only of the letters 'a', 'b' and 'c'; the
// characters are used directly as indices into 3-element arrays.
//
// For every end index i the substrings s[j+1..i] are split by how many
// distinct letters they contain:
//   * one letter    - the start lies after the most recent occurrence of
//                     either other letter;
//   * two letters   - the start lies after the last occurrence of the excluded
//                     letter and the other two letters' prefix-count
//                     difference is the same at j and at i;
//   * three letters - both prefix-count differences are the same at j and i.
// Prefix end indices are stored per difference value in ascending order, so
// each of those counts is a binary search. Overall O(n log n) time.
static ll countUniformSubstrings(const string& s) {
    const int n = static_cast<int>(s.size());

    // last[c]: index of the most recent occurrence of letter c, -1 if unseen.
    array<int, 3> last = {-1, -1, -1};
    // cnt[c]: occurrences of letter c in the prefix s[0..i].
    array<int, 3> cnt = {0, 0, 0};

    // pairDiff[t][d]: ascending prefix end indices at which the count
    // difference of the two letters other than t equals d
    // (t = 0: b-c, t = 1: a-c, t = 2: a-b).
    array<map<int, vector<int>>, 3> pairDiff;
    // tripleDiff[{b-c, a-c}]: ascending prefix end indices sharing both
    // differences, i.e. all three counts moved by the same amount in between.
    map<pair<int, int>, vector<int>> tripleDiff;

    // Index -1 stands for the empty prefix, whose differences are all zero.
    for (auto& byDiff : pairDiff) byDiff[0].push_back(-1);
    tripleDiff[{0, 0}].push_back(-1);

    ll res = 0;
    for (int i = 0; i < n; i++) {
        const int c = s[i] - 'a';
        cnt[c]++;
        last[c] = i;

        const array<int, 3> diffs = {cnt[1] - cnt[2], cnt[0] - cnt[2], cnt[0] - cnt[1]};
        for (int t = 0; t < 3; t++) pairDiff[t][diffs[t]].push_back(i);
        // std::map references stay valid across later insertions.
        vector<int>& sameTriple = tripleDiff[{diffs[0], diffs[1]}];
        sameTriple.push_back(i);

        // p[2] == i; p[1] and p[0] are the last occurrences of the other two
        // letters, more recent one first.
        array<int, 3> p = last;
        sort(all(p));

        // Single-letter substrings: every start in (p[1], i].
        res += i - p[1];
        if (p[1] == -1) continue;  // only one distinct letter seen so far

        // t is the letter seen least recently. Positions of distinct letters
        // differ and p[1] != -1 here, so exactly one letter matches p[0].
        int t = 0;
        while (last[t] != p[0]) t++;

        // Two-letter substrings: prefix ends j >= last[t] with the same
        // difference as i; the trailing -1 discards j == i itself.
        const vector<int>& samePair = pairDiff[t][diffs[t]];
        res += samePair.end() - lower_bound(all(samePair), last[t]) - 1;

        // Three-letter substrings: prefix ends j <= last[t] with both
        // differences equal. Only possible once all three letters have occurred.
        if (p[0] != -1) {
            res += lower_bound(all(sameTriple), last[t] + 1) - sameTriple.begin();
        }
    }
    return res;
}

// Reads one whitespace-delimited token from stdin and prints the number of
// its substrings in which all occurring letters are equally frequent.
int32_t main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    string s;
    cin >> s;

    // Anything outside 'a'..'c' would index the 3-element tables out of range.
    const bool valid = all_of(all(s), [](char ch) { return ch >= 'a' && ch <= 'c'; });
    if (!valid) {
        cerr << "input must consist only of the letters a, b and c\n";
        return 1;
    }

    cout << countUniformSubstrings(s) << "\n";
}