// Reads one word from stdin and prints how many of its non-empty substrings
// are "balanced": every letter that occurs in the substring occurs the same
// number of times. The count is split into three disjoint parts:
//   * substrings consisting of one repeated letter,
//   * substrings with exactly two distinct letters in equal amounts,
//   * substrings with #a == #b == #c.
// NOTE(review): the input is assumed to use only the letters 'a', 'b', 'c';
// other characters are not rejected.
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

namespace {

// Number of unordered pairs that can be formed from k items.
long long pairsAmong(long long k) {
    return k * (k - 1) / 2;
}

// Counts substrings made of a single repeated letter. The run of equal
// letters ending at position i contributes one substring per start inside it.
long long countSingleLetter(const std::string &s) {
    long long total = 0;
    long long run = 0;
    for (int i = 0; i < static_cast<int>(s.size()); ++i) {
        run = (i > 0 && s[i] == s[i - 1]) ? run + 1 : 1;
        total += run;
    }
    return total;
}

// Counts substrings with #a == #b == #c. A substring s[i..j) has this property
// exactly when the prefix vectors (#a - #c, #b - #c) at i and j coincide, so
// sort all n + 1 prefix vectors and count pairs inside each group of equal ones.
long long countThreeLetterBalanced(const std::string &s) {
    const int n = static_cast<int>(s.size());
    std::vector<std::pair<int, int> > prefix(n + 1, std::make_pair(0, 0));
    for (int i = 0; i < n; ++i) {
        prefix[i + 1] = prefix[i];
        switch (s[i]) {
            case 'a': ++prefix[i + 1].first; break;
            case 'b': ++prefix[i + 1].second; break;
            case 'c': --prefix[i + 1].first; --prefix[i + 1].second; break;
            default: break;
        }
    }
    std::sort(prefix.begin(), prefix.end());

    long long total = 0;
    int groupStart = 0;
    for (int i = 1; i <= n + 1; ++i) {
        // i == n + 1 closes the final group (no mutation of shared state needed).
        if (i == n + 1 || prefix[i] != prefix[groupStart]) {
            total += pairsAmong(i - groupStart);
            groupStart = i;
        }
    }
    return total;
}

// Counts substrings that do not contain `excluded` and contain `up` and `down`
// equally often (both at least once, since substrings are non-empty).
// Occurrences of `excluded` split s into independent segments; within a
// segment, two prefixes with the same (#up - #down) bound such a substring.
long long countTwoLetterBalanced(const std::string &s, char excluded,
                                 char up, char down) {
    const int n = static_cast<int>(s.size());
    // seen[offset + d]: prefixes of the current segment with #up - #down == d.
    // |d| never exceeds n, so 2n + 1 slots suffice.
    std::vector<int> seen(2 * n + 1, 0);
    const int offset = n;
    int diff = 0;
    int lo = 0;  // smallest d touched in the current segment
    int hi = 0;  // largest d touched in the current segment
    seen[offset] = 1;  // the empty prefix at the segment start

    long long total = 0;
    for (int i = 0; i < n; ++i) {
        const char ch = s[i];
        if (ch == up || ch == down) {
            diff += (ch == up) ? 1 : -1;
            total += seen[offset + diff]++;
            lo = std::min(lo, diff);
            hi = std::max(hi, diff);
        } else if (ch == excluded) {
            // Reset for the next segment. Clearing only [lo, hi] costs at most
            // the segment's length, which keeps the whole pass linear.
            std::fill(seen.begin() + (offset + lo),
                      seen.begin() + (offset + hi + 1), 0);
            diff = lo = hi = 0;
            seen[offset] = 1;
        }
    }
    return total;
}

}  // namespace

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(0);

    // Whitespace-delimited token, like scanf("%s"), but with no fixed-size
    // buffer to overflow; stays empty (answer 0) if input is missing.
    std::string s;
    std::cin >> s;

    long long result = countSingleLetter(s);
    result += countThreeLetterBalanced(s);
    result += countTwoLetterBalanced(s, 'a', 'b', 'c');
    result += countTwoLetterBalanced(s, 'b', 'c', 'a');
    result += countTwoLetterBalanced(s, 'c', 'a', 'b');

    std::cout << result << '\n';
    return 0;
}