#include <bits/stdc++.h> using namespace std; const int MX=100100; struct Node { long long mn,mx,add; } v[270270]; int n,m,cnt,tot,i,a[MX],b[MX],k[MX]; long long le,ri,mid,mn,s[MX]; char r[MX],sr[MX],ch,oth; bool cmp(int i, int j) { if (b[i]!=b[j]) return b[i]<b[j]; return i>j; } void init(int i, int L, int R) { v[i].add=0; if (L==R) { v[i].mn=v[i].mx=s[R]; r[R]=sr[R]; return; } int mid=(L+R)/2; init(i*2,L,mid); init(i*2+1,mid+1,R); v[i].mn=min(v[i*2].mn,v[i*2+1].mn)+v[i].add; v[i].mx=max(v[i*2].mx,v[i*2+1].mx)+v[i].add; } long long fmin(int i, int L, int R, int rgh) { if (R<=rgh) return v[i].mn; int mid=(L+R)/2; if (rgh<=mid) return fmin(i*2,L,mid,rgh)+v[i].add; return min(v[i*2].mn,fmin(i*2+1,mid+1,R,rgh))+v[i].add; } long long fmax(int i, int L, int R, int lft) { if (L>=lft) return v[i].mx; int mid=(L+R)/2; if (lft>mid) return fmax(i*2+1,mid+1,R,lft)+v[i].add; return max(v[i*2+1].mx,fmax(i*2,L,mid,lft))+v[i].add; } void upd(int i, int L, int R, int pos) { if (L>=pos) { v[i].mn+=b[pos]; v[i].mx+=b[pos]; v[i].add+=b[pos]; return; } int mid=(L+R)/2; if (pos<=mid) upd(i*2,L,mid,pos); upd(i*2+1,mid+1,R,pos); v[i].mn=min(v[i*2].mn,v[i*2+1].mn)+v[i].add; v[i].mx=max(v[i*2].mx,v[i*2+1].mx)+v[i].add; } void solve(long long lim) { //printf("solve %lld\n",lim); init(1,0,n); for (cnt=i=0; i<tot; i++) { int pos=k[i]; //printf("pos=%d | %lld vs %lld\n",pos,fmin(1,0,n,pos-1),fmax(1,0,n,pos)); if (fmin(1,0,n,pos-1)+lim>=fmax(1,0,n,pos)+b[pos]) { //puts("OK"); r[pos]=oth; if (++cnt==m) break; upd(1,0,n,pos); } } } int main() { scanf("%d%d",&n,&m); for (i=1; i<=n; i++) scanf("%d",&a[i]); for (i=1; i<=n; i++) { scanf("%d",&b[i]); if (b[i]<a[i]) { swap(a[i],b[i]); r[i]='B'; } else { r[i]='A'; ++cnt; } b[i]-=a[i]; s[i]=s[i-1]+a[i]; le=max(le,s[i]-mn); mn=min(mn,s[i]); ri+=b[i]; } if (m==cnt) { printf("%lld\n",le); puts(r+1); return 0; } if (m>cnt) { m-=cnt; ch='B'; oth='A'; } else { m=cnt-m; ch='A'; oth='B'; } for (i=1; i<=n; i++) if (r[i]==ch) { if (b[i]==0) { r[i]=oth; if (--m==0) break; } else k[tot++]=i; } if (m==0) { printf("%lld\n",le); puts(r+1); return 0; } for (i=1; i<=n; i++) sr[i]=r[i]; sort(k,k+tot,cmp); mid=le; ri+=le; //printf("%lld %lld | %lld\n",le,m,ri); //puts(r+1); //for (i=0; i<=n; i++) printf("%lld[%d] ",s[i],b[i]); puts("s"); while (le<ri) { mid=(le+ri)/2; solve(mid); if (cnt==m) ri=mid; else le=mid+1; } printf("%lld\n",ri); if (mid!=ri) solve(ri); puts(r+1); return 0; }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 | #include <bits/stdc++.h> using namespace std; const int MX=100100; struct Node { long long mn,mx,add; } v[270270]; int n,m,cnt,tot,i,a[MX],b[MX],k[MX]; long long le,ri,mid,mn,s[MX]; char r[MX],sr[MX],ch,oth; bool cmp(int i, int j) { if (b[i]!=b[j]) return b[i]<b[j]; return i>j; } void init(int i, int L, int R) { v[i].add=0; if (L==R) { v[i].mn=v[i].mx=s[R]; r[R]=sr[R]; return; } int mid=(L+R)/2; init(i*2,L,mid); init(i*2+1,mid+1,R); v[i].mn=min(v[i*2].mn,v[i*2+1].mn)+v[i].add; v[i].mx=max(v[i*2].mx,v[i*2+1].mx)+v[i].add; } long long fmin(int i, int L, int R, int rgh) { if (R<=rgh) return v[i].mn; int mid=(L+R)/2; if (rgh<=mid) return fmin(i*2,L,mid,rgh)+v[i].add; return min(v[i*2].mn,fmin(i*2+1,mid+1,R,rgh))+v[i].add; } long long fmax(int i, int L, int R, int lft) { if (L>=lft) return v[i].mx; int mid=(L+R)/2; if (lft>mid) return fmax(i*2+1,mid+1,R,lft)+v[i].add; return max(v[i*2+1].mx,fmax(i*2,L,mid,lft))+v[i].add; } void upd(int i, int L, int R, int pos) { if (L>=pos) { v[i].mn+=b[pos]; v[i].mx+=b[pos]; v[i].add+=b[pos]; return; } int mid=(L+R)/2; if (pos<=mid) upd(i*2,L,mid,pos); upd(i*2+1,mid+1,R,pos); v[i].mn=min(v[i*2].mn,v[i*2+1].mn)+v[i].add; v[i].mx=max(v[i*2].mx,v[i*2+1].mx)+v[i].add; } void solve(long long lim) { //printf("solve %lld\n",lim); init(1,0,n); for (cnt=i=0; i<tot; i++) { int pos=k[i]; //printf("pos=%d | %lld vs %lld\n",pos,fmin(1,0,n,pos-1),fmax(1,0,n,pos)); if (fmin(1,0,n,pos-1)+lim>=fmax(1,0,n,pos)+b[pos]) { //puts("OK"); r[pos]=oth; if (++cnt==m) break; upd(1,0,n,pos); } } } int main() { scanf("%d%d",&n,&m); for (i=1; i<=n; i++) scanf("%d",&a[i]); for (i=1; i<=n; i++) { scanf("%d",&b[i]); if (b[i]<a[i]) { swap(a[i],b[i]); r[i]='B'; } else { r[i]='A'; ++cnt; } b[i]-=a[i]; s[i]=s[i-1]+a[i]; le=max(le,s[i]-mn); mn=min(mn,s[i]); ri+=b[i]; } if (m==cnt) { printf("%lld\n",le); puts(r+1); return 0; } if (m>cnt) { m-=cnt; ch='B'; oth='A'; } else { m=cnt-m; ch='A'; oth='B'; } for (i=1; i<=n; i++) if (r[i]==ch) { if (b[i]==0) { r[i]=oth; if (--m==0) break; } else k[tot++]=i; } if (m==0) { printf("%lld\n",le); puts(r+1); return 0; } for (i=1; i<=n; i++) sr[i]=r[i]; sort(k,k+tot,cmp); mid=le; ri+=le; //printf("%lld %lld | %lld\n",le,m,ri); //puts(r+1); //for (i=0; i<=n; i++) printf("%lld[%d] ",s[i],b[i]); puts("s"); while (le<ri) { mid=(le+ri)/2; solve(mid); if (cnt==m) ri=mid; else le=mid+1; } printf("%lld\n",ri); if (mid!=ri) solve(ri); puts(r+1); return 0; } |