#include <bits/stdc++.h>
using namespace std;
const int MX=67067,md=998244353;
struct Node {
int fi[6],lst[6],cnt[6][6],f[6][6];
} b[MX*2];
int n,m,q,qpos[MX];
char s[MX],qs[MX];
void upd(int& x, int v) {
if ((x+=v)>=md) x-=md;
}
void recalc(int i) {
for (int j=0; j<m; j++) for (int k=0; k<m; k++) {
b[i].cnt[j][k]=0;
b[i].f[j][k]=0;
}
for (int j=0; j<m; j++) {
b[i].fi[j]=min(b[i*2].fi[j],b[i*2+1].fi[j]);
b[i].lst[j]=max(b[i*2].lst[j],b[i*2+1].lst[j]);
for (int k=0; k<m; k++) {
int rj=b[i*2].lst[j];
int lk=b[i*2+1].fi[k];
if (rj!=-1 && b[i*2].lst[k]!=-1) {
//if (i==1 && j==0 && k==1) printf("UPD-1-C <- %d\n",b[i*2].cnt[j][k]);
upd(b[i].cnt[j][k],b[i*2].cnt[j][k]);
if (lk==n+1) {
upd(b[i].f[j][k],b[i*2].f[j][k]);
//if (j==0 && k==1) printf("UPD-1-F <- %d\n",b[i*2].f[j][k]);
}
}
if (rj==-1 && lk!=n+1 && b[i*2+1].fi[j]!=-1) {
upd(b[i].cnt[j][k],b[i*2+1].cnt[j][k]);
upd(b[i].f[j][k],b[i*2+1].f[j][k]);
/*if (i==2 && j==0 && k==0) {
printf("UPD-2-C <- %d\n",b[i*2+1].cnt[j][k]);
printf("UPD-2-F <- %d\n",b[i*2+1].f[j][k]);
}*/
}
if (rj!=-1 && lk!=n+1 && b[i*2].lst[k]<=rj) for (int x=0; x<m; x++) {
long long lft=b[i*2].f[x][j];
//if (i==2 && j==0) printf("lft=%lld at %d\n",lft,x);
if (lft) for (int y=0; y<m; y++) {
long long rgh=b[i*2+1].cnt[k][y];
if (rgh) {
b[i].cnt[x][y]=(b[i].cnt[x][y]+lft*rgh)%md;
//if (i==2 && x==0 && y==0) printf("(%d %d)C += %lld * %lld\n",j,k,lft,rgh);
}
rgh=b[i*2+1].f[k][y];
if (rgh && b[i*2+1].fi[j]>=lk) {
b[i].f[x][y]=(b[i].f[x][y]+lft*rgh)%md;
//if (i==2 && x==0 && y==0) printf("(%d %d)F += %lld * %lld\n",j,k,lft,rgh);
}
}
}
}
}
}
void debug(int i, int L, int R) {
printf("%d [%d..%d]\n",i,L,R);
for (int j=0; j<m; j++) for (int k=0; k<m; k++) printf(" %d %d = %d %d\n",j,k,b[i].cnt[j][k],b[i].f[j][k]);
if (L==R) return;
int mid=(L+R)/2;
debug(i*2,L,mid);
debug(i*2+1,mid+1,R);
}
void solve() {
int tot_cnt=0,tot_f=0;
for (int j=0; j<m; j++) for (int k=0; k<m; k++) {
upd(tot_cnt,b[1].cnt[j][k]);
upd(tot_f,b[1].f[j][k]);
}
if ((tot_cnt-=tot_f)<0) tot_cnt+=md;
printf("%d\n",tot_cnt);
}
void init(int i, int L, int R) {
if (L==R) {
for (int j=0; j<m; j++) {
b[i].fi[j]=n+1;
b[i].lst[j]=-1;
for (int k=0; k<m; k++) {
b[i].cnt[j][k]=0;
b[i].f[j][k]=0;
}
}
int cur=s[R]-'a';
b[i].fi[cur]=b[i].lst[cur]=R;
b[i].cnt[cur][cur]=1;
b[i].f[cur][cur]=1;
return;
}
int mid=(L+R)/2;
init(i*2,L,mid);
init(i*2+1,mid+1,R);
recalc(i);
}
void upd(int i, int L, int R, int pos) {
if (L==R) {
for (int j=0; j<m; j++) {
b[i].fi[j]=n+1;
b[i].lst[j]=-1;
for (int k=0; k<m; k++) {
b[i].cnt[j][k]=0;
b[i].f[j][k]=0;
}
}
int cur=s[R]-'a';
b[i].fi[cur]=b[i].lst[cur]=R;
b[i].cnt[cur][cur]=1;
b[i].f[cur][cur]=1;
return;
}
int mid=(L+R)/2;
if (pos<=mid) upd(i*2,L,mid,pos); else upd(i*2+1,mid+1,R,pos);
recalc(i);
}
int main() {
scanf("%d%d",&n,&q);
scanf("%s",s+1);
char mxc='a';
for (int i=1; i<=n; i++) if (s[i]>mxc) mxc=s[i];
for (int i=0; i<q; i++) {
scanf("%d",&qpos[i]);
scanf("%s",qs+i);
if (qs[i]>mxc) mxc=qs[i];
}
m=mxc-'a'+1;
init(1,1,n);
//debug(1,1,n);
solve();
for (int i=0; i<q; i++) {
s[qpos[i]]=qs[i];
upd(1,1,n,qpos[i]);
solve();
}
return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 | #include <bits/stdc++.h> using namespace std; const int MX=67067,md=998244353; struct Node { int fi[6],lst[6],cnt[6][6],f[6][6]; } b[MX*2]; int n,m,q,qpos[MX]; char s[MX],qs[MX]; void upd(int& x, int v) { if ((x+=v)>=md) x-=md; } void recalc(int i) { for (int j=0; j<m; j++) for (int k=0; k<m; k++) { b[i].cnt[j][k]=0; b[i].f[j][k]=0; } for (int j=0; j<m; j++) { b[i].fi[j]=min(b[i*2].fi[j],b[i*2+1].fi[j]); b[i].lst[j]=max(b[i*2].lst[j],b[i*2+1].lst[j]); for (int k=0; k<m; k++) { int rj=b[i*2].lst[j]; int lk=b[i*2+1].fi[k]; if (rj!=-1 && b[i*2].lst[k]!=-1) { //if (i==1 && j==0 && k==1) printf("UPD-1-C <- %d\n",b[i*2].cnt[j][k]); upd(b[i].cnt[j][k],b[i*2].cnt[j][k]); if (lk==n+1) { upd(b[i].f[j][k],b[i*2].f[j][k]); //if (j==0 && k==1) printf("UPD-1-F <- %d\n",b[i*2].f[j][k]); } } if (rj==-1 && lk!=n+1 && b[i*2+1].fi[j]!=-1) { upd(b[i].cnt[j][k],b[i*2+1].cnt[j][k]); upd(b[i].f[j][k],b[i*2+1].f[j][k]); /*if (i==2 && j==0 && k==0) { printf("UPD-2-C <- %d\n",b[i*2+1].cnt[j][k]); printf("UPD-2-F <- %d\n",b[i*2+1].f[j][k]); }*/ } if (rj!=-1 && lk!=n+1 && b[i*2].lst[k]<=rj) for (int x=0; x<m; x++) { long long lft=b[i*2].f[x][j]; //if (i==2 && j==0) printf("lft=%lld at %d\n",lft,x); if (lft) for (int y=0; y<m; y++) { long long rgh=b[i*2+1].cnt[k][y]; if (rgh) { b[i].cnt[x][y]=(b[i].cnt[x][y]+lft*rgh)%md; //if (i==2 && x==0 && y==0) printf("(%d %d)C += %lld * %lld\n",j,k,lft,rgh); } rgh=b[i*2+1].f[k][y]; if (rgh && b[i*2+1].fi[j]>=lk) { b[i].f[x][y]=(b[i].f[x][y]+lft*rgh)%md; //if (i==2 && x==0 && y==0) printf("(%d %d)F += %lld * %lld\n",j,k,lft,rgh); } } } } } } void debug(int i, int L, int R) { printf("%d [%d..%d]\n",i,L,R); for (int j=0; j<m; j++) for (int k=0; k<m; k++) printf(" %d %d = %d %d\n",j,k,b[i].cnt[j][k],b[i].f[j][k]); if (L==R) return; int mid=(L+R)/2; debug(i*2,L,mid); debug(i*2+1,mid+1,R); } void solve() { int tot_cnt=0,tot_f=0; for (int j=0; j<m; j++) for (int k=0; k<m; k++) { upd(tot_cnt,b[1].cnt[j][k]); upd(tot_f,b[1].f[j][k]); } if ((tot_cnt-=tot_f)<0) tot_cnt+=md; printf("%d\n",tot_cnt); } void init(int i, int L, int R) { if (L==R) { for (int j=0; j<m; j++) { b[i].fi[j]=n+1; b[i].lst[j]=-1; for (int k=0; k<m; k++) { b[i].cnt[j][k]=0; b[i].f[j][k]=0; } } int cur=s[R]-'a'; b[i].fi[cur]=b[i].lst[cur]=R; b[i].cnt[cur][cur]=1; b[i].f[cur][cur]=1; return; } int mid=(L+R)/2; init(i*2,L,mid); init(i*2+1,mid+1,R); recalc(i); } void upd(int i, int L, int R, int pos) { if (L==R) { for (int j=0; j<m; j++) { b[i].fi[j]=n+1; b[i].lst[j]=-1; for (int k=0; k<m; k++) { b[i].cnt[j][k]=0; b[i].f[j][k]=0; } } int cur=s[R]-'a'; b[i].fi[cur]=b[i].lst[cur]=R; b[i].cnt[cur][cur]=1; b[i].f[cur][cur]=1; return; } int mid=(L+R)/2; if (pos<=mid) upd(i*2,L,mid,pos); else upd(i*2+1,mid+1,R,pos); recalc(i); } int main() { scanf("%d%d",&n,&q); scanf("%s",s+1); char mxc='a'; for (int i=1; i<=n; i++) if (s[i]>mxc) mxc=s[i]; for (int i=0; i<q; i++) { scanf("%d",&qpos[i]); scanf("%s",qs+i); if (qs[i]>mxc) mxc=qs[i]; } m=mxc-'a'+1; init(1,1,n); //debug(1,1,n); solve(); for (int i=0; i<q; i++) { s[qpos[i]]=qs[i]; upd(1,1,n,qpos[i]); solve(); } return 0; } |
English