#include <cstdio>
#include <cstring>
#include <unordered_map>

// Reads one word from stdin and prints how many of its non-empty substrings
// fall into one of three categories (see single, double_ and triple below):
//   * one repeated character,
//   * no occurrence of a chosen letter, and the remaining characters split
//     evenly between "same as the first character" and "anything else",
//   * equally many 'a', 'b' and other characters.
// NOTE(review): the categories are disjoint only when the word is built from
// the letters 'a', 'b' and 'c' - confirm against the task statement.

using lli = long long;

// Size of the input buffer in main, terminating NUL included. Every word
// handed to the helpers below must therefore be shorter than N characters.
const int N = 300300;

// Scratch histogram used by double_part: counts[N + d] is the number of
// prefixes of the current segment whose running balance equals d.
// double_part zeroes every cell it touched before it returns.
int counts[2 * N];

// Result of scanning one segment in double_part.
struct Part {
    lli sum;  // number of balanced substrings found inside the segment
    int end;  // index of the first character after the segment
};

// Pairwise count differences of a prefix: (#a - #b, #a - #c, #b - #c), where
// "c" stands for every character that is neither 'a' nor 'b'.
struct Triple {
    int aOverB;
    int aOverC;
    int bOverC;

    bool operator==(const Triple &other) const {
        return aOverB == other.aOverB && aOverC == other.aOverC &&
               bOverC == other.bOverC;
    }
};

// Polynomial hash over the three differences of a Triple.
struct TripleHasher {
    std::size_t operator()(const Triple &t) const {
        std::size_t res = 17;
        // Negative differences wrap modulo 2^w on conversion, which is
        // well-defined for unsigned types and still deterministic.
        res = res * 31 + static_cast<std::size_t>(t.aOverB);
        res = res * 31 + static_cast<std::size_t>(t.aOverC);
        res = res * 31 + static_cast<std::size_t>(t.bOverC);
        return res;
    }
};

// Returns the number of non-empty substrings of `word` that consist of one
// repeated character. A position that is the k-th character of its run ends
// exactly k such substrings, so the answer is the sum of those k values.
lli single(const char *word) {
    lli sum = 0;
    char current = '\0';
    int run = 0;
    for (const char *p = word; *p != '\0'; ++p) {
        if (*p != current) {
            current = *p;
            run = 0;
        }
        sum += ++run;
    }
    return sum;
}

// Scans the segment of `word` that starts at index `start` and runs up to
// (not including) the next `omit` character or the terminating NUL.
//
// Each character equal to the segment's first character moves a running
// balance up by one, every other character moves it down by one. A substring
// is counted when the balance before it equals the balance after it, i.e.
// when it contains as many "first" characters as other characters.
//
// Returns the count together with the index at which the segment ended.
// Precondition: strlen(word) < N, otherwise the balance can leave `counts`.
Part double_part(const char *word, int start, char omit) {
    const char first = word[start];
    lli sum = 0;
    int offset = N;  // balance 0 is stored in the middle of `counts`
    int min = N;
    int max = N;
    int i = start;

    counts[N] = 1;  // the empty prefix has balance 0
    for (; word[i] != '\0' && word[i] != omit; ++i) {
        offset += (word[i] == first) ? 1 : -1;
        // Every earlier prefix with this balance closes one balanced
        // substring ending at position i.
        sum += counts[offset]++;
        if (offset < min) {
            min = offset;
        } else if (offset > max) {
            max = offset;
        }
    }

    // Reset only the cells that were touched, so that scanning all segments
    // of a word stays linear in its length.
    std::memset(counts + min, 0,
                sizeof(int) * static_cast<std::size_t>(max - min + 1));
    return Part{sum, i};
}

// Sums double_part over every maximal segment of `word` that does not contain
// `omit`, i.e. counts the balanced substrings (as defined by double_part)
// that avoid `omit` altogether.
lli double_(const char *word, char omit) {
    lli sum = 0;
    int i = 0;
    while (word[i] != '\0') {
        if (word[i] == omit) {
            ++i;
            continue;
        }
        const Part part = double_part(word, i, omit);
        sum += part.sum;
        i = part.end;  // either an `omit` character or the terminating NUL
    }
    return sum;
}

// Returns the number of non-empty substrings of `word` that contain equally
// many 'a', 'b' and "other" characters (anything that is neither 'a' nor 'b'
// is treated as the third letter). Two prefixes delimit such a substring
// exactly when their Triple of pairwise differences is identical.
lli triple(const char *word) {
    std::unordered_map<Triple, int, TripleHasher> seen;
    Triple current{0, 0, 0};
    seen[current] = 1;  // the empty prefix

    lli sum = 0;
    for (const char *p = word; *p != '\0'; ++p) {
        switch (*p) {
            case 'a':
                ++current.aOverB;
                ++current.aOverC;
                break;
            case 'b':
                --current.aOverB;
                ++current.bOverC;
                break;
            default:
                --current.aOverC;
                --current.bOverC;
                break;
        }
        // operator[] value-initialises a missing entry to 0, so one lookup
        // both reads the previous count and records this prefix.
        sum += seen[current]++;
    }
    return sum;
}

// Reads a single whitespace-delimited word from stdin and prints the combined
// count of single-, double- and triple-balanced substrings followed by '\n'.
// If no word can be read the word is treated as empty and 0 is printed.
int main() {
    // The field width in the format string below must stay equal to N - 1.
    static_assert(N == 300300, "update the scanf field width when N changes");

    // Static storage keeps the ~300 KB buffer off the stack.
    static char word[N];
    if (std::scanf("%300299s", word) != 1) {
        word[0] = '\0';
    }

    lli res = 0;
    res += single(word);
    res += double_(word, 'c');
    res += double_(word, 'b');
    res += double_(word, 'a');
    res += triple(word);

    std::printf("%lld\n", res);
    return 0;
}