#include <cstdio>

// Reads two binary trees over nodes 1..n (given as parent arrays) and prints
// an accumulated count modulo 1e9+7.
// NOTE(review): the problem statement is not part of this file; the count is
// presumably the number of rotations needed to turn tree 0 into tree 1 —
// confirm against the task before relying on that interpretation.
//
// NOTE(review): dfs/lineup/fnd/solve recurse along tree paths, so recursion
// depth can reach n for a path-shaped tree; with n close to MX this may need
// a larger-than-default stack.

const int MX = 500500, md = 1000000007;

// Per tree j (0 or 1) and node i:
//   le[j][i]  left child of i, or -1 when there is none
//   ri[j][i]  right child of i, or -2 when there is none
//   lft/rgh   smallest / largest key in the subtree of i (filled by dfs)
// rt[j] is the root of tree j; res is the running answer, kept below md.
int n, res, rt[2], lft[2][MX], rgh[2][MX], le[2][MX], ri[2][MX];

// Records the key interval [L, R] covered by node i's subtree in tree j.
// The left child gets [L, i-1] and the right child [i+1, R], i.e. each tree
// is treated as a binary search tree whose in-order sequence is 1..n.
void dfs(int i, int j, int L, int R) {
    lft[j][i] = L;
    rgh[j][i] = R;
    if (le[j][i] > 0) dfs(le[j][i], j, L, i - 1);
    if (ri[j][i] > 0) dfs(ri[j][i], j, i + 1, R);
}

// Walks tree 0's subtree rooted at i (a negative i is an empty sentinel).
// Left children are visited with lftc == true, right children with false.
// A node with a left child adds (i - lft) when visited with lftc == false.
// A node with a right child adds (rgh - i) when visited with lftc == true.
void lineup(int i, bool lftc) {
    if (i < 0) return;
    if (le[0][i] > 0) {
        lineup(le[0][i], true);
        if (!lftc) res += i - lft[0][i];
    }
    if (ri[0][i] > 0) {
        lineup(ri[0][i], false);
        if (lftc) res += rgh[0][i] - i;
    }
    if (res >= md) res -= md;
}

// Descends tree 0 from node i towards key w, charging the subtree that hangs
// off the opposite side of every visited node (via lineup plus the span of
// that side).
// lftc == true : keep going left while i > w; returns le[0][i] of the first
//                visited node with i <= w.
// lftc == false: keep going right while i < w; returns ri[0][i] of the first
//                visited node with i >= w.
// A negative i (empty sentinel) is returned unchanged.
int fnd(int i, int w, bool lftc) {
    if (i < 0) return i;
    if (lftc) {
        if (ri[0][i] > 0) {
            lineup(ri[0][i], false);
            // NOTE(review): the original author flagged this term with "?".
            res += rgh[0][i] - i;
            if (res >= md) res -= md;
        }
        return (i > w) ? fnd(le[0][i], w, true) : le[0][i];
    }
    if (le[0][i] > 0) {
        lineup(le[0][i], true);
        // NOTE(review): the original author flagged this term with "?".
        res += i - lft[0][i];
        if (res >= md) res -= md;
    }
    return (i < w) ? fnd(ri[0][i], w, false) : ri[0][i];
}

// Processes the key range [L, R]. pb is the root of that range in tree 1 and
// pa the tree-0 node paired with it, or one of the sentinels -1 / -2.
// -1 is handled like a tree-0 subtree lying entirely left of pb (adds R - pb).
// -2 is handled like one lying entirely right of pb (adds pb - L).
// Otherwise the cost is charged via fnd, and the left and right key ranges
// of pb are recursed on with the adjusted tree-0 roots.
void solve(int pa, int pb, int L, int R) {
    if (pa == -1 || (pa > 0 && rgh[0][pa] < pb)) {
        res += R - pb;
        if (res >= md) res -= md;
        if (L < pb) solve(pa, le[1][pb], L, pb - 1);
        if (R > pb) solve(-2, ri[1][pb], pb + 1, R);
        return;
    }
    if (pa == -2 || (pa > 0 && lft[0][pa] > pb)) {
        res += pb - L;
        if (res >= md) res -= md;
        if (L < pb) solve(-1, le[1][pb], L, pb - 1);
        if (R > pb) solve(pa, ri[1][pb], pb + 1, R);
        return;
    }
    int lrt = le[0][pa];
    int rrt = ri[0][pa];
    if (L < lft[0][pa]) {
        lrt = -1;
        rrt = fnd(pa, pb, false);
        res += pb - L;
    } else if (R > rgh[0][pa]) {
        lrt = fnd(pa, pb, true);
        rrt = -2;
        res += R - pb;
    } else if (pa < pb) {
        rrt = fnd(rrt, pb, false);
        res += pb - pa;
    } else if (pa > pb) {
        lrt = fnd(lrt, pb, true);
        res += pa - pb;
    }
    if (res >= md) res -= md;
    if (L < pb) solve(lrt, le[1][pb], L, pb - 1);
    if (R > pb) solve(rrt, ri[1][pb], pb + 1, R);
}

// Input: n, then for each of the two trees n parent indices for nodes 1..n.
// A parent of -1 marks the root. A parent x < i makes i the right child of x.
// A parent x > i makes i the left child of x.
int main() {
    std::scanf("%d", &n);
    // Reject n that would index past the fixed-size arrays (indices 1..n).
    if (n >= MX) return 1;
    for (int j = 0; j < 2; j++) {
        for (int i = 1; i <= n; i++) {
            le[j][i] = -1;
            ri[j][i] = -2;
        }
        for (int i = 1; i <= n; i++) {
            int x;
            std::scanf("%d", &x);
            if (x == -1)
                rt[j] = i;
            else if (x < i)
                ri[j][x] = i;
            else if (x > i)
                le[j][x] = i;
        }
        dfs(rt[j], j, 1, n);
    }
    solve(rt[0], rt[1], 1, n);
    std::printf("%d\n", res);
    return 0;
}