#include <bits/stdc++.h>
using namespace std;
// Competitive-programming shorthands used throughout the file.
// NOTE: `st` / `nd` are object-like macros, so every identifier token `st` or
// `nd` after this point is rewritten to `first` / `second` (used on pii below).
#define st first
#define nd second
using ll = long long;
using pii = pair<int, int>;
// Capacity of the input buffer: room for N - 1 characters plus the NUL.
// Presumably sized for an input limit of 3 * 10^5 characters — TODO confirm
// against the problem statement.
const int N = 3e5 + 7;
// The input string (global, hence zero-initialized) and its length as set by
// main via strlen; both are read by every count* helper.
char s[N];
int n;
// Number of non-empty substrings of s[0, n) made up solely of character c.
// A maximal run of length k holds 1 + 2 + ... + k such substrings; adding the
// current run length at every position accumulates exactly that sum.
ll count1(char c) {
    ll total = 0;
    int run = 0;  // length of the run of c ending at the current position
    for (int pos = 0; pos < n; ++pos) {
        run = (s[pos] == c) ? run + 1 : 0;
        total += run;  // adds nothing when the run was just broken
    }
    return total;
}
// Number of non-empty substrings of s[0, n) that contain only c1 and c2 and
// hold the same number of each (so both letters occur). Callers pass distinct
// c1 / c2.
// balance = (#c1 - #c2) over the current segment's prefix; two prefixes with
// the same balance bound a substring with equal counts. Any other character
// cannot be inside such a substring, so it starts a fresh segment.
ll count2(char c1, char c2) {
    ll total = 0;
    int balance = 0;
    map<int, int> seen;  // balance value -> number of prefixes reaching it
    seen[0] = 1;         // the empty prefix of the segment
    for (int pos = 0; pos < n; ++pos) {
        const char ch = s[pos];
        if (ch != c1 && ch != c2) {
            seen.clear();
            seen[0] = 1;
            balance = 0;
            continue;
        }
        balance += (ch == c1) ? 1 : -1;
        int &hits = seen[balance];
        total += hits;  // pair with every earlier prefix of equal balance
        ++hits;
    }
    return total;
}
// Number of non-empty substrings of s[0, n) in which 'a', 'b' and 'c' each
// occur the same number of times (so all three occur).
// d = (#b - #a, #c - #b) over the current segment's prefix; two prefixes with
// the same d bound a substring with equal counts of all three letters.
// Any character outside {a, b, c} can never be part of such a substring, so
// it resets the state, matching how count2 treats out-of-pair characters
// (previously such characters were silently counted as 'c').
ll count3() {
    ll res = 0;
    map<pii, int> mp {{{0, 0}, 1}};  // difference vector -> prefix count
    pii d = {};
    for (int i = 0; i < n; ++i) {
        if (s[i] == 'a')
            --d.st;
        else if (s[i] == 'b')
            ++d.st, --d.nd;
        else if (s[i] == 'c')
            ++d.nd;
        else {
            // separator: start a fresh segment after this position
            mp.clear();
            mp[pii(0, 0)] = 1;
            d = pii(0, 0);
            continue;
        }
        res += mp[d]++;  // pair with every earlier prefix having the same d
    }
    return res;
}
int main() {
scanf("%s", s);
n = strlen(s);
ll res = count1('a') + count1('b') + count1('c');
res += count2('a', 'b') + count2('a', 'c') + count2('b', 'c');
res += count3();
printf("%lld\n", res);
return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 | #include <bits/stdc++.h> using namespace std; #define st first #define nd second using ll = long long; using pii = pair<int, int>; const int N = 3e5 + 7; char s[N]; int n; ll count1(char c) { ll res = 0; int w = 0; for (int i = 0; i < n; ++i) if (s[i] == c) res += ++w; else w = 0; return res; } ll count2(char c1, char c2) { ll res = 0; map<int, int> mp {{0, 1}}; int d = 0; for (int i = 0; i < n; ++i) { if (s[i] == c1) res += mp[++d]++; else if (s[i] == c2) res += mp[--d]++; else { mp.clear(); mp[0] = 1; d = 0; } } return res; } ll count3() { ll res = 0; map<pii, int> mp {{{0, 0}, 1}}; pii d = {}; for (int i = 0; i < n; ++i) { if (s[i] == 'a') --d.st; else if (s[i] == 'b') ++d.st, --d.nd; else ++d.nd; res += mp[d]++; } return res; } int main() { scanf("%s", s); n = strlen(s); ll res = count1('a') + count1('b') + count1('c'); res += count2('a', 'b') + count2('a', 'c') + count2('b', 'c'); res += count3(); printf("%lld\n", res); return 0; } |
English