// Counts the substrings of a string over {'a', 'b', 'c'} in which every letter
// that occurs, occurs the same number of times (substrings are counted by
// position, so equal substrings at different offsets count separately).
//
// Input:  one whitespace-delimited token on stdin, made only of 'a', 'b', 'c'.
// Output: the count, followed by '\n'.
#include <algorithm>
#include <array>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

namespace {

// Number of elements of the ascending vector `v` that lie in [lo, hi).
long long count_in_range(const std::vector<int>& v, int lo, int hi) {
    return std::lower_bound(v.begin(), v.end(), hi) -
           std::lower_bound(v.begin(), v.end(), lo);
}

// Returns how many substrings s[l..r] (1-indexed, 1 <= l <= r <= |s|) have all
// of their distinct letters occurring equally often.
// Precondition: every character of `s` is 'a', 'b' or 'c'.
//
// For each right end i, the possible left ends split into three ranges by the
// last positions of the two other letters (x <= y):
//   l in (y, i]  -> only s[i]'s letter occurs: always balanced;
//   l in (x, y]  -> exactly two letters occur: need equal counts;
//   l in [1, x]  -> all three letters occur: need all counts equal.
// The last two are answered by matching prefix keys, with the prefix length
// p = l - 1 restricted to the range via binary search on ascending lists.
long long count_balanced_substrings(const std::string& s) {
    const int n = static_cast<int>(s.size());

    std::array<int, 3> cnt{};   // cnt[k]: occurrences of letter k in s[1..i]
    std::array<int, 3> last{};  // last[k]: last position of letter k in s[1..i], 0 if none

    // by_letter[k][p - 2 * cnt_k(p)] lists prefix lengths p in ascending order.
    // Two prefixes p < i share a key exactly when letter k fills half of s[p+1..i].
    std::array<std::map<int, std::vector<int>>, 3> by_letter;

    // by_diffs[(cnt_a - cnt_b, cnt_b - cnt_c)] lists prefix lengths in ascending
    // order. A shared key means s[p+1..i] holds a, b and c equally often.
    std::map<std::pair<int, int>, std::vector<int>> by_diffs;

    // The empty prefix (p = 0) is a valid starting point for every key.
    by_diffs[{0, 0}].push_back(0);
    for (auto& keyed : by_letter) {
        keyed[0].push_back(0);
    }

    long long res = 0;
    for (int i = 1; i <= n; ++i) {
        const int id = s[i - 1] - 'a';
        ++cnt[id];
        last[id] = i;

        const int other1 = last[(id + 1) % 3];
        const int other2 = last[(id + 2) % 3];
        const int x = std::min(other1, other2);
        const int y = std::max(other1, other2);

        // l in (y, i]: a single repeated letter.
        res += i - y;

        // l in (x, y]: two letters with equal counts <=> prefix p in [x, y)
        // with the same letter-`id` key as prefix i.
        res += count_in_range(by_letter[id][i - 2 * cnt[id]], x, y);

        // l in [1, x]: three letters with equal counts <=> prefix p in [0, x)
        // with the same pair of count differences as prefix i.
        auto& same_diffs = by_diffs[{cnt[0] - cnt[1], cnt[1] - cnt[2]}];
        res += std::lower_bound(same_diffs.begin(), same_diffs.end(), x) -
               same_diffs.begin();
        same_diffs.push_back(i);

        // Record prefix i under every letter's key for later right ends.
        for (int k = 0; k < 3; ++k) {
            by_letter[k][i - 2 * cnt[k]].push_back(i);
        }
    }
    return res;
}

}  // namespace

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // std::string has no fixed capacity, unlike the former char buffer.
    // If the read fails, s stays empty and the answer is 0.
    std::string s;
    std::cin >> s;

    // Letters outside 'a'..'c' would index past the 3-element tables.
    if (s.find_first_not_of("abc") != std::string::npos) {
        std::cerr << "input must consist only of 'a', 'b', 'c'\n";
        return 1;
    }

    std::cout << count_balanced_substrings(s) << '\n';
    return 0;
}