// For a word and a sequence of point updates, print (after reading the word and
// after every update) how many distinct non-empty subsequences of the word can
// be formed from at least two different sets of positions, modulo MOD.
//
// Requires C++17 (structured bindings, [[maybe_unused]]).
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Modulus applied by solve() to its result.
constexpr long long MOD = 998244353;

using count_type = long long;

/// Reduces x into [0, MOD); also correct for negative x.
long long mod(long long x) { return (x % MOD + MOD) % MOD; }

/// Exponential reference implementation: counts the distinct non-empty
/// subsequences of `word` that can be formed from at least two different sets
/// of positions. Not reduced modulo MOD. Only practical for short words; kept
/// to cross-check solve().
///
/// Embedding counts are saturated at 2 because only "more than one" matters,
/// which also keeps the counters from overflowing.
count_type solve_brute(std::vector<char> &word) {
    // ways[s]: number of position sets (capped at 2) spelling s within the
    // prefix processed so far; the empty subsequence seeds the process.
    std::map<std::string, int> ways = {{"", 1}};
    for (const char character : word) {
        // Start from the subsequences that skip `character`, then add the
        // ones that end with it (built from the previous snapshot only).
        std::map<std::string, int> extended(ways);
        for (const auto &[subsequence, count] : ways) {
            int &slot = extended[subsequence + character];
            slot = std::min(slot + count, 2);
        }
        ways = std::move(extended);
    }

    count_type result = 0;
    for (const auto &[subsequence, count] : ways) {
        if (!subsequence.empty() && count > 1) {
            ++result;
        }
    }
    return result;
}

/// Counts, modulo MOD, the distinct non-empty subsequences of `word` that can
/// be formed from at least two different sets of positions (the quantity
/// solve_brute() computes exactly).
///
///   answer = #distinct non-empty subsequences
///          - #distinct non-empty subsequences with exactly one embedding
///
/// A subsequence has exactly one embedding iff its greedy-leftmost and
/// greedy-rightmost embeddings coincide (every embedding lies componentwise
/// between those two). Such an embedding is a chain p1 < ... < pk where
///   * p1 is the first occurrence of its letter in the word,
///   * pk is the last occurrence of its letter in the word,
///   * for consecutive a < b, neither word[a] nor word[b] occurs strictly
///     between a and b.
/// The last condition forces `a` to be the latest occurrence of its letter
/// before `b`, so each position only inspects one candidate per letter:
/// O(n * alphabet) per call.
count_type solve(std::vector<char> &word) {
    const std::size_t n = word.size();
    constexpr std::size_t kAlphabet = 256;  // indexed by unsigned char

    // distinct[i]: distinct subsequences (including the empty one) of the
    // first i letters, mod MOD.
    std::vector<count_type> distinct(n + 1, 0);
    // unique_chain[i]: chains as described above whose last position is i
    // (1-based), ignoring the "pk is a last occurrence" condition, mod MOD.
    std::vector<count_type> unique_chain(n + 1, 0);
    // last_pos[c]: 1-based position of the latest occurrence of c so far, 0 if none.
    std::vector<std::size_t> last_pos(kAlphabet, 0);
    // Letters seen so far, so the inner loop only visits letters that occur.
    std::vector<unsigned char> seen_letters;

    distinct[0] = 1;
    for (std::size_t i = 1; i <= n; ++i) {
        const auto letter = static_cast<unsigned char>(word[i - 1]);
        const std::size_t prev = last_pos[letter];

        // Classic distinct-subsequence recurrence: double, then remove the
        // subsequences already ending in `letter` at its previous occurrence.
        distinct[i] = mod(2 * distinct[i - 1] - (prev != 0 ? distinct[prev - 1] : 0));

        // A chain may start here only if this is the letter's first occurrence.
        count_type chains = (prev == 0) ? 1 : 0;
        for (const unsigned char other : seen_letters) {
            // Predecessor must be the latest `other` before i, and no copy of
            // `letter` may lie strictly between it and i.
            if (last_pos[other] >= prev) {
                chains += unique_chain[last_pos[other]];
            }
        }
        unique_chain[i] = mod(chains);

        if (prev == 0) {
            seen_letters.push_back(letter);
        }
        last_pos[letter] = i;
    }

    // A chain is complete only if it ends at the final occurrence of its letter.
    count_type unique_total = 0;
    for (const unsigned char letter : seen_letters) {
        unique_total += unique_chain[last_pos[letter]];
    }

    return mod(distinct[n] - 1 - unique_total);
}

/// Reads one update "p z" (1-based position, new letter), applies it to
/// `word` and prints the answer for the updated word. The position is
/// bounds-checked against `word` itself; `n` is kept for interface
/// compatibility.
void process_query([[maybe_unused]] int n, std::vector<char> &word) {
    int p = 0;
    char z = 0;
    if (std::scanf("%d %c", &p, &z) != 2) {
        return;  // malformed or truncated input: nothing to apply
    }
    if (p >= 1 && static_cast<std::size_t>(p) <= word.size()) {
        word[p - 1] = z;
    }
    std::printf("%lld\n", solve(word));
}

int main() {
    int n = 0;
    int q = 0;
    if (std::scanf("%d %d", &n, &q) != 2 || n < 0) {
        return 1;
    }

    std::vector<char> word(static_cast<std::size_t>(n));
    for (char &letter : word) {
        // " %c" skips any whitespace (including newlines) before the letter.
        if (std::scanf(" %c", &letter) != 1) {
            return 1;
        }
    }

    std::printf("%lld\n", solve(word));
    for (int i = 0; i < q; ++i) {
        process_query(n, word);
    }
    return 0;
}