#include <array>
#include <cstddef>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

namespace {

using Key = std::pair<int, int>;
using PrefixTally = std::map<Key, long long>;

// Adds to `total` how many earlier prefixes produced `key`, then records the
// current prefix under `key` (a single map lookup).
void tally(PrefixTally& seen, const Key& key, long long& total) {
    total += seen[key]++;
}

// Counts the non-empty substrings of `text` in which every letter that occurs
// occurs the same number of times.
//
// `text` may contain only 'a', 'b' and 'c'; any other character throws
// std::invalid_argument.
//
// Let cnt = counts of 'a','b','c' in a prefix. The substring between two
// prefixes qualifies iff the two prefixes agree on one of these keys:
//   * one_letter[c]  (cnt[c+1], cnt[c+2])           -> substring is only letter c
//   * two_letters[c] (cnt[c+1] - cnt[c+2], cnt[c])  -> no letter c, the other two
//                                                      appear equally often
//   * all_equal      (cnt[0] - cnt[1], cnt[0] - cnt[2])
//                                                   -> all three appear equally
// (indices mod 3). A non-empty substring has 1, 2 or 3 distinct letters, so it
// matches at most one of these cases and is counted exactly once.
long long count_balanced_substrings(const std::string& text) {
    std::array<int, 3> cnt{};                    // running counts of 'a','b','c'
    std::array<PrefixTally, 3> one_letter;       // indexed by the only letter present
    std::array<PrefixTally, 3> two_letters;      // indexed by the absent letter
    PrefixTally all_equal;

    long long total = 0;
    // pos == 0 records the empty prefix; pos == k records the first k characters.
    for (std::size_t pos = 0; pos <= text.size(); ++pos) {
        if (pos > 0) {
            const char ch = text[pos - 1];
            if (ch < 'a' || ch > 'c') {
                throw std::invalid_argument(
                    "invalid character at position " + std::to_string(pos) +
                    "; expected only 'a', 'b' or 'c'");
            }
            ++cnt[ch - 'a'];
        }
        for (int c = 0; c < 3; ++c) {
            const int x = cnt[(c + 1) % 3];
            const int y = cnt[(c + 2) % 3];
            tally(two_letters[c], Key(x - y, cnt[c]), total);
            tally(one_letter[c], Key(x, y), total);
        }
        tally(all_equal, Key(cnt[0] - cnt[1], cnt[0] - cnt[2]), total);
    }
    return total;
}

}  // namespace

// Reads one whitespace-delimited word from stdin and prints the number of its
// balanced substrings (see count_balanced_substrings). Missing input is treated
// as an empty word and prints 0; an invalid character is reported on stderr
// with exit status 1.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::string text;
    std::cin >> text;  // std::string grows as needed: no fixed-size buffer to overflow

    try {
        std::cout << count_balanced_substrings(text) << '\n';
    } catch (const std::invalid_argument& err) {
        std::cerr << err.what() << '\n';
        return 1;
    }
    return 0;
    // <3 cxr
}