#include <cstdio>
#include <cstring>
#include <unordered_map>
// 64-bit accumulator type: substring counts grow roughly quadratically with
// the input length and do not fit in an int.
using lli = long long int;
// Capacity bound shared by the input buffer in main and the balance table
// below.
constexpr int N = 300300;
// Balance histogram used by double_part, indexed by N + balance so that
// negative balances land on valid indices. It has static storage (so starts
// zeroed) and double_part clears the range it touched before returning.
int counts[2 * N];
// Result of one double_part call.
struct Part {
  lli sum;  // number of balanced substrings found in the scanned run
  int end;  // index where the run stopped (an `omit` character or the NUL)
};
// Hash-map key used by triple(): pairwise differences of the running letter
// counts of a prefix. Two prefixes with equal keys delimit a substring with
// equally many of each letter. As maintained by triple(), bOverC always equals
// aOverC - aOverB; it is nevertheless part of equality and hashing.
struct Triple {
  int aOverB;  // #'a' - #'b' in the prefix
  int aOverC;  // #'a' - #'c' in the prefix
  int bOverC;  // #'b' - #'c' in the prefix
  bool operator==(const Triple &other) const {
    return aOverB == other.aOverB && aOverC == other.aOverC && bOverC == other.bOverC;
  }
};
// Hash functor for Triple keys in std::unordered_map.
//
// Why not a simple polynomial: 17*31^3 + 961*aOverB + 31*aOverC + bOverC
// reduces, for keys with bOverC == aOverC - aOverB, to a constant plus
// 960*aOverB + 32*aOverC. Keys stepping by (+1, -30, -31) then collide. An
// input such as ("a" followed by 31 'c') repeated maps every prefix onto one
// of only 32 hash values and makes the map quadratic. The fields are instead
// packed into 64-bit words and run through the SplitMix64 finalizer.
struct TripleHasher {
  // SplitMix64 finalizer: a bijective 64-bit mixer with strong avalanche.
  static unsigned long long mix(unsigned long long x) {
    x += 0x9e3779b97f4a7c15ULL;
    x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9ULL;
    x = (x ^ (x >> 27)) * 0x94d049bb133111ebULL;
    return x ^ (x >> 31);
  }
  std::size_t operator()(const Triple &t) const {
    // Reinterpret each int as its 32-bit unsigned pattern, so negatives pack cleanly.
    const unsigned long long lo = static_cast<unsigned int>(t.aOverB);
    const unsigned long long hi = static_cast<unsigned int>(t.aOverC);
    unsigned long long h = mix(lo | (hi << 32));
    h = mix(h ^ static_cast<unsigned int>(t.bOverC));
    return static_cast<std::size_t>(h);
  }
};
// Counts substrings made of one repeated letter. A maximal run of length k
// contributes 1 + 2 + ... + k of them: each position adds the length of the
// run that ends there.
lli single(char *word) {
  lli total = 0;
  int run = 0;       // length of the run of equal letters ending at *p
  char prev = '\0';  // letter at the previous position ('\0' before the start)
  for (const char *p = word; *p != '\0'; ++p) {
    run = (*p == prev) ? run + 1 : 1;
    prev = *p;
    total += run;
  }
  return total;
}
// Counts the balanced substrings of one run of `word` that starts at `start`
// and extends up to (not including) the next `omit` character or the NUL.
// "Balanced" means the run's first letter and the other letter occur equally
// often. This assumes the input alphabet is {a, b, c}, so a run avoiding `omit`
// holds at most two distinct letters.
//
// `offset` tracks N + (#first - #other) over the current prefix of the run.
// counts[offset] holds how many earlier prefixes (including the empty one) had
// that same balance, and each of them closes one balanced substring here.
// The touched range [min, max] of counts is zeroed before returning, so every
// call starts from an all-zero table. Indices stay within
// [N - runLength, N + runLength], which fits in counts as long as the input is
// shorter than N.
//
// @param word   NUL-terminated input string.
// @param start  index of the run's first character; callers (double_) pass a
//               position that is neither `omit` nor the NUL.
// @param omit   letter that terminates the run.
// @return the balanced-substring count and the index where the run stopped.
Part double_part(char *word, int start, char omit) {
  lli sum = 0;
  int offset = N;
  int min = N;
  int max = N;
  const char first = word[start];
  int i = start;
  counts[N] = 1;  // the empty prefix has balance 0
  for (char c = word[i]; c && c != omit; c = word[++i]) {
    offset += (c == first) ? +1 : -1;
    sum += counts[offset]++;
    // offset moves by exactly one per step, so it can only extend one bound.
    if (offset < min) {
      min = offset;
    } else if (offset > max) {
      max = offset;
    }
  }
  memset(counts + min, 0, sizeof(int) * (max - min + 1));
  // Plain aggregate init: designated initializers are C++20-only.
  return Part{sum, i};
}
// Sums double_part over every maximal run of `word` that avoids `omit`. For
// input over {a, b, c}, the result is the number of substrings made of the two
// letters other than `omit` in equal, non-zero amounts.
lli double_(char *word, char omit) {
  lli total = 0;
  int pos = 0;
  while (word[pos] != '\0') {
    if (word[pos] == omit) {
      ++pos;  // separators are skipped one at a time
      continue;
    }
    const Part run = double_part(word, pos, omit);
    total += run.sum;
    pos = run.end;  // resume at the character that ended the run
  }
  return total;
}
// Counts substrings of `word` in which 'a', 'b' and 'c' occur equally often.
// Every character that is neither 'a' nor 'b' is treated as 'c'.
//
// For each prefix it forms the Triple of pairwise count differences. Two
// prefixes with the same Triple bound a balanced substring, so each prefix
// adds the number of earlier prefixes (including the empty one) sharing its key.
lli triple(char *word) {
  lli sum = 0;
  Triple current = {0, 0, 0};  // plain aggregate init; designated init is C++20
  std::unordered_map<Triple, int, TripleHasher> seen;
  seen.reserve(strlen(word) + 1);  // at most one distinct key per prefix
  seen[current] = 1;               // the empty prefix
  for (int i = 0; word[i]; ++i) {
    if (word[i] == 'a') {
      ++current.aOverB;
      ++current.aOverC;
    } else if (word[i] == 'b') {
      --current.aOverB;
      ++current.bOverC;
    } else {
      --current.aOverC;
      --current.bOverC;
    }
    // One lookup: operator[] value-initializes a missing key to 0.
    int &occurrences = seen[current];
    sum += occurrences;
    ++occurrences;
  }
  return sum;
}
// Reads one whitespace-delimited word from stdin. Prints the number of its
// substrings in which every letter present occurs equally often (one, two or
// three distinct letters).
int main() {
  // Static storage: a ~300 KB array may exceed some default stack limits.
  static char word[N];
  // Build "%<N-1>s" from N so the read can never overrun `word`. Longer input
  // is truncated rather than overflowing the buffer.
  char format[16];
  snprintf(format, sizeof format, "%%%ds", N - 1);
  if (scanf(format, word) != 1) {
    word[0] = '\0';  // nothing read: treat as the empty string (answer 0)
  }
  lli res = 0;
  res += single(word);
  res += double_(word, 'c');
  res += double_(word, 'b');
  res += double_(word, 'a');
  res += triple(word);
  printf("%lld\n", res);
  return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 | #include <cstdio> #include <cstring> #include <unordered_map> typedef long long int lli; const int N = 300300; int counts[2 * N]; struct Part { lli sum; int end; }; struct Triple { int aOverB; int aOverC; int bOverC; bool operator==(const Triple &other) const { return (aOverB == other.aOverB && aOverC == other.aOverC && bOverC == other.bOverC); } }; struct TripleHasher { std::size_t operator()(const Triple &t) const { size_t res = 17; res = res * 31 + t.aOverB; res = res * 31 + t.aOverC; res = res * 31 + t.bOverC; return res; } }; lli single(char *word) { char current = '\0'; int inc = 0; lli sum = 0; for (int i = 0, letter; letter = word[i]; ++i) { if (current != letter) { inc = 0; current = letter; } sum += ++inc; } return sum; } Part double_part(char *word, int start, char omit) { lli sum = 0; int offset = N; int min = N; int max = N; int diff = 0; char first = word[start]; int i = start; counts[N] = 1; for (char c = word[i]; c && c != omit; c = word[++i]) { offset += ((c == first) ? 
+1 : -1); sum += counts[offset]++; if (offset < min) { min = offset; } else if (offset > max) { max = offset; } } memset(counts + min, 0, sizeof(int) * (max - min + 1)); return {.sum = sum, .end = i}; } lli double_(char *word, char omit) { lli sum = 0; for (int i = 0; word[i];) { if (word[i] != omit) { Part part = double_part(word, i, omit); i = part.end; sum += part.sum; } else { ++i; } } return sum; } lli triple(char *word) { lli sum = 0; Triple current = {.aOverB = 0, .aOverC = 0, .bOverC = 0}; std::unordered_map<Triple, int, TripleHasher> hash; hash[current] = 1; for (int i = 0; word[i]; ++i) { if (word[i] == 'a') { ++current.aOverB; ++current.aOverC; } else if (word[i] == 'b') { --current.aOverB; ++current.bOverC; } else { --current.aOverC; --current.bOverC; } auto search = hash.find(current); int inc = search == hash.end() ? 0 : search->second; hash[current] = 1 + inc; sum += inc; } return sum; } int main() { char word[N]; scanf("%s", word); lli res = 0; res += single(word); res += double_(word, 'c'); res += double_(word, 'b'); res += double_(word, 'a'); res += triple(word); printf("%lld\n", res); return 0; } |
English