#include<bits/stdc++.h>
using namespace std;
// Shared state for the two trees read by main(). Nodes are numbered from 1;
// index 0 doubles as "no child" because the arrays start zero-filled and the
// code tests child slots against 0.
//   oj1 / oj2     - parent of each node as read from input (-1 marks the root)
//   lsyn1 / lsyn2 - child whose index is smaller than the node's own index
//   psyn1 / psyn2 - child whose index is larger than the node's own index
//   sz            - scratch sizes written by cost() (slot 0 is reset to 0 there)
//   wyn           - running answer, kept reduced modulo M and printed by main()
//   n             - number of nodes read from input
int oj1[500009], lsyn1[500009], psyn1[500009], oj2[500009], lsyn2[500009], psyn2[500009], sz[500009], wyn = 0, n;
// Modulus applied to every accumulated value.
const int M = 1e9 + 7;
/**
 * Traverses the part of the first tree (lsyn1 / psyn1) hanging below `u`,
 * filling `sz` for every node it visits, and returns an accumulated count
 * reduced modulo M.
 *
 * Each visited node that has at least one child contributes the `sz` of its
 * larger-index child when its tag `oj` is 1, and the `sz` of its
 * smaller-index child otherwise. Smaller-index children are visited with
 * tag 1, larger-index children with tag 2.
 *
 * Node `a` is special: only the child selected by the tag is descended into
 * (the other side is ignored), so sz[a] counts just that one side plus `a`.
 *
 * @param u   node to start from
 * @param a   node at which the walk follows a single side only
 * @param poz forwarded unchanged to recursive calls; not otherwise read
 * @param oj  direction tag of `u` (1 or anything else)
 * @return    accumulated count modulo M
 */
int cost(int u, int a, int poz, int oj)
{
    const int smaller = lsyn1[u];
    const int larger = psyn1[u];

    // A childless node has size one and adds nothing.
    if (smaller == 0 && larger == 0)
    {
        sz[u] = 1;
        return 0;
    }

    if (u == a)
    {
        const bool followLarger = (oj == 1);
        const int next = followLarger ? larger : smaller;

        int below = 0;
        if (next == 0)
            sz[next] = 0; // slot 0 means "no child": clear it before it is read
        else
            below = cost(next, a, poz, followLarger ? 2 : 1);

        sz[u] = (sz[next] + 1) % M;
        return (below + sz[next]) % M;
    }

    int fromSmaller = 0;
    if (smaller == 0)
        sz[smaller] = 0;
    else
        fromSmaller = cost(smaller, a, poz, 1);

    int fromLarger = 0;
    if (larger == 0)
        sz[larger] = 0;
    else
        fromLarger = cost(larger, a, poz, 2);

    const int childrenTotal = (fromSmaller + fromLarger) % M;
    sz[u] = (sz[smaller] + sz[larger] + 1) % M;

    const int own = (oj == 1) ? sz[larger] : sz[smaller];
    return (childrenTotal + own) % M;
}
/**
 * Rewires the first tree (lsyn1 / psyn1) so that node `u` takes the place of
 * node `v`, adds the price of doing so to `wyn`, then recurses into both
 * children of `u` as given by the second tree (lsyn2 / psyn2).
 *
 * main() starts this with the second tree's root as `u` and the first
 * tree's root as `v`.
 * NOTE(review): the rewiring writes i-1 / i+1 as neighbours, so it presumably
 * relies on the nodes between the end points having consecutive indices
 * (index order == in-order) -- confirm against the task statement.
 *
 * @param u node taken from the second tree
 * @param v node currently at the corresponding position in the first tree
 */
void solve(int u, int v)
{
    // A node with no children in the second tree needs no further work.
    if (lsyn2[u] == 0 && psyn2[u] == 0)
        return;

    if (u < v)
    {
        // Follow smaller-index children from v until the index is no longer above u.
        int stop = v;
        while (stop > u)
            stop = lsyn1[stop];

        const int price = cost(lsyn1[v], stop, 1, 1);
        wyn = (price + wyn) % M;
        wyn = ((v - u) % M + wyn) % M;

        // Nodes stop+1 .. u become a chain linked through lsyn1.
        for (int i = stop + 1; i <= u; ++i)
        {
            lsyn1[i] = i - 1;
            if (i != u)
                psyn1[i] = 0;
        }

        // Must precede the next loop: when stop == u that loop sets psyn1[u] again.
        psyn1[stop] = 0;
        lsyn1[v] = 0;

        // Nodes u .. v-1 become a chain linked through psyn1, ending at v.
        for (int i = v - 1; i >= u; --i)
        {
            psyn1[i] = i + 1;
            if (i != u)
                lsyn1[i] = 0;
        }
    }
    else if (u > v)
    {
        // Mirror image: follow larger-index children from v until reaching at least u.
        int stop = v;
        while (stop < u)
            stop = psyn1[stop];

        const int price = cost(psyn1[v], stop, 2, 2);
        wyn = (price + wyn) % M;
        wyn = ((u - v) % M + wyn) % M;

        // Must precede the loops: when stop == u the first loop sets lsyn1[u] again.
        lsyn1[stop] = 0;
        psyn1[v] = 0;

        // Nodes v+1 .. u become a chain linked through lsyn1, ending at v.
        for (int i = v + 1; i <= u; ++i)
        {
            lsyn1[i] = i - 1;
            if (i != u)
                psyn1[i] = 0;
        }

        // Nodes u .. stop-1 become a chain linked through psyn1, ending at stop.
        for (int i = stop - 1; i >= u; --i)
        {
            psyn1[i] = i + 1;
            if (i != u)
                lsyn1[i] = 0;
        }
    }

    // psyn1[u] is read only after the first call returns, as in the original order.
    solve(lsyn2[u], lsyn1[u]);
    solve(psyn2[u], psyn1[u]);
}
/**
 * Reads `count` parent indices (one per node, nodes numbered from 1) and
 * fills the child tables from them.
 *
 * A parent with a larger index than the node stores the node in `smaller`;
 * a parent with a smaller index stores it in `larger`. The entry -1 marks
 * the root.
 *
 * @param count   number of nodes to read
 * @param parent  receives the raw parent index of every node
 * @param smaller child table for children whose index is below the parent's
 * @param larger  child table for children whose index is above the parent's
 * @param root    receives the index of the node marked -1 (0 if none)
 * @return false when a value cannot be read, a parent index lies outside
 *         1..count, a node names itself as parent, or no root is present
 */
static bool readTree(int count, int parent[], int smaller[], int larger[], int &root)
{
    root = 0;
    for (int i = 1; i <= count; i++)
    {
        if (scanf("%d", &parent[i]) != 1)
            return false;

        const int p = parent[i];
        if (p == -1)
            root = i;
        else if (p < 1 || p > count || p == i)
            return false; // would index outside the tables or create a self-loop
        else if (p > i)
            smaller[p] = i;
        else
            larger[p] = i;
    }
    return root != 0;
}

/**
 * Reads the node count and the parent arrays of both trees, runs solve()
 * from the second tree's root against the first tree's root, and prints the
 * accumulated answer `wyn`.
 *
 * @return 0 on success, 1 when the input is missing or malformed
 */
int main()
{
    // Nodes use indices 1..n, so n must stay strictly below the table length.
    const int capacity = static_cast<int>(sizeof(oj1) / sizeof(oj1[0]));
    if (scanf("%d", &n) != 1 || n < 1 || n >= capacity)
        return 1;

    int r1 = 0;
    int r2 = 0;
    if (!readTree(n, oj1, lsyn1, psyn1, r1))
        return 1;
    if (!readTree(n, oj2, lsyn2, psyn2, r2))
        return 1;

    solve(r2, r1);
    printf("%d", wyn);
    return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 | #include<bits/stdc++.h> using namespace std; int oj1[500009], lsyn1[500009], psyn1[500009], oj2[500009], lsyn2[500009], psyn2[500009], sz[500009], wyn = 0, n; const int M = 1e9 + 7; int cost(int u, int a, int poz, int oj) { int b = 0, c, d = 0; if(lsyn1[u] == 0 && psyn1[u] == 0) { sz[u] = 1; return 0; } if(u == a) { if(oj == 1) { if(psyn1[u] != 0) b = cost(psyn1[u], a, poz, 2); else sz[psyn1[u]] = 0; sz[u] = (sz[psyn1[u]] + 1) % M; return (b + sz[psyn1[u]]) % M; } else { if(lsyn1[u] != 0) b = cost(lsyn1[u], a, poz, 1); else sz[lsyn1[u]] = 0; sz[u] = (sz[lsyn1[u]] + 1) % M; return (b + sz[lsyn1[u]]) % M; } } if(lsyn1[u] != 0) b = cost(lsyn1[u], a, poz, 1); else sz[lsyn1[u]] = 0; if(psyn1[u] != 0) d = cost(psyn1[u], a, poz, 2); else sz[psyn1[u]] = 0; b = (b + d) % M; sz[u] = (sz[lsyn1[u]] + sz[psyn1[u]] + 1) % M; if(oj == 1) b = (b + sz[psyn1[u]]) % M; else b = (b + sz[lsyn1[u]]) % M; return b; } void solve(int u, int v) { int a, b; if(lsyn2[u] == 0 && psyn2[u] == 0) return; if(u < v) { a = v; while(a > u) a = lsyn1[a]; b = cost(lsyn1[v], a, 1, 1); wyn = (b + wyn) % M; b = (v - u) % M; wyn = (b + wyn) % M; for(int i = u; i > a; i--) { lsyn1[i] = i-1; if(i != u) psyn1[i] = 0; } psyn1[a] = 0; lsyn1[v] = 0; for(int i=u; i<v; i++) { if(i != u) lsyn1[i] = 0; psyn1[i] = i+1; } } else if(u > v) { a = v; while(a < u) a = psyn1[a]; b = cost(psyn1[v], a, 2, 2); wyn = (b + wyn) % M; b = (u - v) % M; wyn = (b + wyn) % M; lsyn1[a] = 0; psyn1[v] = 0; for(int i = u; i > v; i--) { lsyn1[i] = i-1; if(i != u) psyn1[i] = 0; } for(int i=u; i<a; i++) { if(i != u) lsyn1[i] = 0; psyn1[i] = i+1; } } solve(lsyn2[u], lsyn1[u]); solve(psyn2[u], 
psyn1[u]); return; } int main() { int a,b,c,r1,r2; scanf("%d", &n); for(int i=1; i<=n; i++) { scanf("%d", &oj1[i]); if(oj1[i] == -1) r1 = i; else if(oj1[i] > i) lsyn1[oj1[i]] = i; else psyn1[oj1[i]] = i; } for(int i=1; i<=n; i++) { scanf("%d", &oj2[i]); if(oj2[i] == -1) r2 = i; else if(oj2[i] > i) lsyn2[oj2[i]] = i; else psyn2[oj2[i]] = i; } solve(r2, r1); printf("%d", wyn); return 0; } |
English