// Reads two binary search trees over the keys 1..n, each given as a parent
// array (-1 marks the root), and prints a count modulo 1e9+7. The count is
// built by reshaping tree 1 top-down so that every subtree root matches the
// corresponding root in tree 2.
// NOTE(review): this looks like the number of rotations turning tree 1 into
// tree 2, but the problem statement is not in this file -- confirm.
//
// Identifier glossary (Polish abbreviations kept from the original):
//   lsyn / psyn = left / right child, wyn = result.
#include <cstdio>

namespace {

const int MAXN = 500009;   // capacity of every per-key array; keys are 1..n
const int M = 1000000007;  // modulus applied to the printed answer

// Child links of tree 1 (the tree being reshaped) and tree 2 (the target).
// Index 0 stands for "no node", so an absent child is stored as 0.
int lsyn1[MAXN], psyn1[MAXN];
int lsyn2[MAXN], psyn2[MAXN];
int sz[MAXN];  // scratch subtree sizes written by cost()
int wyn = 0;   // running answer, kept reduced modulo M
int n;         // number of keys

// Size recorded by cost() for `node`, or 0 for an absent child (node == 0).
int size_of(int node) {
    return node != 0 ? sz[node] : 0;
}

// Post-order walk of tree 1 starting at `u`. It records sz[u] and returns
// (mod M) a sum over every visited node w. When w was entered as a left
// child (oj == 1), the sum adds the size of w's right subtree. Otherwise it
// adds the size of w's left subtree.
// The stop node `a` is special. Only its child on the side opposite to `oj`
// is walked. The child on the `oj` side is neither visited nor counted in
// sz[a].
// NOTE(review): recursion depth equals the height of tree 1, which can
// approach n for a path-shaped tree. That may exceed a typical default
// stack size -- confirm the target's stack limit.
int cost(int u, int a, int oj) {
    const int l = lsyn1[u];
    const int r = psyn1[u];
    if (l == 0 && r == 0) {
        sz[u] = 1;
        return 0;
    }
    if (u == a) {
        const int inner = (oj == 1) ? r : l;
        const int b = (inner != 0) ? cost(inner, a, oj == 1 ? 2 : 1) : 0;
        sz[u] = size_of(inner) + 1;
        return (b + size_of(inner)) % M;
    }
    int b = (l != 0) ? cost(l, a, 1) : 0;
    const int d = (r != 0) ? cost(r, a, 2) : 0;
    b = (b + d) % M;
    sz[u] = size_of(l) + size_of(r) + 1;
    const int opposite = (oj == 1) ? r : l;  // child on the far side from how u hangs
    return (b + size_of(opposite)) % M;
}

// Makes tree-2 node `u` the root of the tree-1 subtree currently rooted at
// `v`, adds the cost of doing so to `wyn`, then recurses into both children.
// Assumes both subtrees hold the same key set (true for valid input).
// NOTE(review): each call walks and relinks keys across the current subtree's
// key range, and recursion depth follows the height of tree 2. Worst case
// may be quadratic time and deep recursion -- confirm against the limits.
void solve(int u, int v) {
    if (lsyn2[u] == 0 && psyn2[u] == 0) return;  // leaf in tree 2: nothing to reshape
    if (u < v) {
        // u lies in v's left subtree. Descend v's left spine to the first key a <= u.
        int a = v;
        while (a > u) a = lsyn1[a];
        const int b = cost(lsyn1[v], a, 1);
        wyn = (b + wyn) % M;
        wyn = ((v - u) % M + wyn) % M;
        // Relink keys a..v around u:
        //   - u-1..a become a chain of left links below u;
        //   - u+1..v become a chain of right links below u;
        //   - a keeps its left subtree and v keeps its right subtree.
        // The order matters when a == u: psyn1[a] = 0 is overwritten by the
        // second loop.
        for (int i = u; i > a; i--) {
            lsyn1[i] = i - 1;
            if (i != u) psyn1[i] = 0;
        }
        psyn1[a] = 0;
        lsyn1[v] = 0;
        for (int i = u; i < v; i++) {
            if (i != u) lsyn1[i] = 0;
            psyn1[i] = i + 1;
        }
    } else if (u > v) {
        // Mirror image: u lies in v's right subtree. Descend the right spine
        // to the first key a >= u.
        int a = v;
        while (a < u) a = psyn1[a];
        const int b = cost(psyn1[v], a, 2);
        wyn = (b + wyn) % M;
        wyn = ((u - v) % M + wyn) % M;
        // The order matters when a == u: lsyn1[a] = 0 is overwritten by the
        // first loop.
        lsyn1[a] = 0;
        psyn1[v] = 0;
        for (int i = u; i > v; i--) {
            lsyn1[i] = i - 1;
            if (i != u) psyn1[i] = 0;
        }
        for (int i = u; i < a; i++) {
            if (i != u) lsyn1[i] = 0;
            psyn1[i] = i + 1;
        }
    }
    solve(lsyn2[u], lsyn1[u]);
    solve(psyn2[u], psyn1[u]);
}

// Reads n parent entries (for keys 1..n) into the child links `lsyn`/`psyn`.
// Returns the root key, or 0 if no -1 entry was seen. Returns -1 on a read
// failure or on a parent outside 1..n, which would index outside the arrays.
int read_tree(int* lsyn, int* psyn) {
    int root = 0;
    for (int i = 1; i <= n; i++) {
        int parent;
        if (std::scanf("%d", &parent) != 1) return -1;
        if (parent == -1) {
            root = i;
        } else if (parent < 1 || parent > n) {
            return -1;
        } else if (parent > i) {
            lsyn[parent] = i;  // smaller key than its parent: left child
        } else {
            psyn[parent] = i;  // otherwise: right child
        }
    }
    return root;
}

}  // namespace

int main() {
    if (std::scanf("%d", &n) != 1 || n < 0 || n >= MAXN) return 1;
    const int r1 = read_tree(lsyn1, psyn1);
    if (r1 < 0) return 1;
    const int r2 = read_tree(lsyn2, psyn2);
    if (r2 < 0) return 1;
    solve(r2, r1);
    std::printf("%d", wyn);
    return 0;
}
// (Removed: a second verbatim copy of the program above, prefixed with a
// "1 2 3 ... 115 |" line-number gutter and followed by a stray "|" --
// apparently residue from copying the code out of a viewer. It redefined
// every global and function above and was not valid C++.)