#include <cstdio>
#include <set>
#include <vector>
#include <map>
#include <string>
// Modulus applied to every count that goes through mod().
const long long MOD = 998244353;
using namespace std;
// Integer type used for all subsequence counts in this file.
using count_type = long long;

// Returns the representative of x in [0, MOD).
// C++ '%' keeps the sign of the dividend, so a negative remainder is shifted
// up by one modulus.
long long mod(long long x) {
  long long remainder = x % MOD;
  if (remainder < 0) {
    remainder += MOD;
  }
  return remainder;
}
// Brute-force reference: returns how many distinct non-empty strings can be
// obtained from `word` as a subsequence in more than one way (i.e. from more
// than one set of positions).
//
// `word` is only read. The work is proportional to the number of distinct
// subsequences, which is still exponential in the worst case, so this is meant
// for small inputs only.
count_type solve_brute(vector<char> &word) {
  // multiplicity[s] = number of position sets in the processed prefix that
  // spell s. The empty string is seeded with its single (empty) position set
  // so that one-letter subsequences are generated from it.
  map<string, count_type> multiplicity;
  multiplicity[""] = 1;

  for (const char letter : word) {
    // Extend a snapshot: strings created by this letter must not be extended
    // by the same letter again within this step.
    const map<string, count_type> before(multiplicity);
    for (const auto &[subsequence, occurrences] : before) {
      multiplicity[subsequence + letter] += occurrences;
    }
  }

  count_type result = 0;
  for (const auto &[subsequence, occurrences] : multiplicity) {
    // The empty string always has multiplicity 1, so it is never counted.
    if (occurrences > 1) {
      result++;
    }
  }
  return result;
}
// Experimental prefix DP over `word`; returns dp2[word.size()].
//
// dp1[i] follows the recurrence dp1[i] = 2 * dp1[i-1] - dp1[last - 1], where
// `last` is the 1-based position of the previous occurrence of word[i-1]
// (the usual distinct-subsequence recurrence, with dp1[0] = 1).
// dp2[i] uses the same recurrence, except that on the second occurrence of a
// character it is overwritten with dp1[last].
// NOTE(review): dp2 looks like a work-in-progress attempt at a fast
// replacement for solve_brute; its results are not verified here and the
// visible callers use solve_brute instead.
//
// Side effects: prints diagnostic traces and both DP tables to stdout.
count_type solve(vector<char> &word) {
  // One slot per possible byte value. The previous size of 'z' (122 slots)
  // made index 'z' itself out of bounds.
  constexpr int kByteValues = 256;
  const int length = static_cast<int>(word.size());

  vector<count_type> dp1(word.size() + 1);
  vector<count_type> dp2(word.size() + 1);
  vector<int> last_index(kByteValues, -1);  // 1-based position of last occurrence, -1 if none
  vector<int> count(kByteValues, 0);        // occurrences seen so far
  dp1[0] = 1;

  for (int i = 1; i <= length; i++) {
    // Index through unsigned char so a negative char cannot index backwards.
    const unsigned char character = static_cast<unsigned char>(word[i - 1]);
    const int last_char_index = last_index[character];

    dp1[i] = mod(dp1[i - 1] * 2);
    dp2[i] = mod(dp2[i - 1] * 2);
    if (last_char_index > 0) {
      dp1[i] = mod(dp1[i] - dp1[last_char_index - 1]);
      dp2[i] = mod(dp2[i] - dp2[last_char_index - 1]);
      if (count[character] == 1) {
        // Second occurrence of this character: trace, then override dp2.
        printf("i: %d lci: %d dp: %lld dp2: %lld\n", i, last_char_index,
               dp1[last_char_index], dp2[i]);
        dp2[i] = mod(dp1[last_char_index]);
        // dp2[i] = mod(dp1[last_char_index] - (dp2[i] + 1));
      }
    }
    last_index[character] = i;
    count[character] += 1;
  }

  // Dump both tables, one per line, for inspection.
  for (const count_type value : dp1) {
    printf("%lld ", value);
  }
  puts("");
  for (const count_type value : dp2) {
    printf("%lld ", value);
  }
  puts("");
  return dp2[word.size()];
}
// Reads one query "p z" from stdin, replaces the p-th character (1-based) of
// `word` with z, and prints solve_brute(word) on its own line.
//
// `n` is accepted for the caller's convenience but not used; the bounds check
// relies on word.size().
// If the query cannot be read or p is outside 1..word.size(), a message is
// written to stderr, `word` is left untouched and nothing is printed to stdout.
void process_query([[maybe_unused]] int n, vector<char> &word) {
  int p = 0;
  char z = '\0';
  // scanf reports how many fields it converted; anything but 2 leaves p/z
  // unusable.
  if (scanf("%d %c", &p, &z) != 2) {
    fprintf(stderr, "process_query: malformed or missing query\n");
    return;
  }
  if (p < 1 || static_cast<std::size_t>(p) > word.size()) {
    fprintf(stderr, "process_query: position %d is outside 1..%zu\n", p,
            word.size());
    return;
  }
  word[p - 1] = z;
  printf("%lld\n", solve_brute(word));
}
// Entry point. Input on stdin: the word length n and query count q, then n
// whitespace-separated characters, then q queries handled by process_query.
// Prints the solve_brute answer for the initial word, then one answer per
// query. Returns 1 if the header or the word cannot be read.
int main() {
  int n = 0;
  int q = 0;
  // A negative n would be converted to a huge size by the vector constructor.
  if (scanf("%d %d", &n, &q) != 2 || n < 0) {
    fprintf(stderr, "expected a non-negative word length and a query count\n");
    return 1;
  }

  vector<char> word(n);
  for (char &letter : word) {
    // The leading "\n" in the format skips any whitespace before the character.
    if (scanf("\n%c", &letter) != 1) {
      fprintf(stderr, "word is shorter than the declared length %d\n", n);
      return 1;
    }
  }

  printf("%lld\n", solve_brute(word));
  for (int i = 0; i < q; i++) {
    process_query(n, word);
  }
  return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 | #include <cstdio> #include <set> #include <vector> #include <map> #include <string> const long long MOD = 998244353; using namespace std; using count_type = long long; long long mod(long long x) { return (x % MOD + MOD) % MOD; } count_type solve_brute(vector<char> &word) { string word_string(word.begin(), word.end()); multiset<string> seen_subsequences = {""}; map<string, count_type> counts; count_type result = 0; // printf("%s\n", word_string.c_str()); for (int i=0; i<word_string.length(); i++) { multiset<string> new_seen_subsequences(seen_subsequences); for (auto subsequence : seen_subsequences) { string new_subsequence = subsequence + word_string[i]; counts[new_subsequence]++; new_seen_subsequences.insert(new_subsequence); } seen_subsequences = new_seen_subsequences; } for (auto [k, count] : counts) { // printf("%s %lld\n", k.c_str(), count); if (count > 1) { result++; } } // puts("------"); return result; } count_type solve(vector<char> &word) { vector<count_type> dp1(word.size() + 1); vector<count_type> dp2(word.size() + 1); vector<int> last_index('z', -1); vector<int> count('z', 0); dp1[0] = 1; for (int i = 1; i <= word.size(); i++) { char character = word[i - 1]; int last_char_index = last_index[character]; dp1[i] = mod(dp1[i - 1] * 2); dp2[i] = mod(dp2[i - 1] * 2); if (last_char_index > 0) { dp1[i] = mod(dp1[i] - dp1[last_char_index - 1]); dp2[i] = mod(dp2[i] - dp2[last_char_index - 1]); if (count[character] == 1) { printf("i: %d lci: %d dp: %lld dp2: %lld\n", i, last_char_index, dp1[last_char_index], dp2[i]); dp2[i] = mod(dp1[last_char_index]); // dp2[i] = mod(dp1[last_char_index] - 
(dp2[i] + 1)); } } last_index[character] = i; count[character] += 1; } for (int i=0; i<dp1.size(); i++) { printf("%lld ", dp1[i]); } puts(""); for (int i=0; i<dp2.size(); i++) { printf("%lld ", dp2[i]); } puts(""); return dp2[word.size()]; } void process_query(int n, vector<char> &word) { int p; char z; scanf("%d %c", &p, &z); word[p - 1] = z; printf("%lld\n", solve_brute(word)); } int main() { int n, q; scanf("%d %d", &n, &q); vector<char> word(n); for (int i = 0; i < n; i++) { scanf("\n%c", &word[i]); } printf("%lld\n", solve_brute(word)); for (int i = 0; i < q; i++) { process_query(n, word); } } |
English