#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <vector>

using ll = long long;

const ll mod = 1000000007;

int n;   // number of keys; each tree is read as a parent array over keys 1..n
ll ans;  // running answer, kept reduced modulo `mod`

// Binary search tree over keys 1..n, read from stdin as a parent array.
struct BST {
    std::vector<int> a;     // a[i] = parent of key i, or -1 for the root
    std::vector<int> e[2];  // e[0][v] / e[1][v] = left / right child of v, or -1
    int root = -1;

    // Reads n parent indices from stdin (uses the global n). A key smaller
    // than its parent becomes that parent's left child, otherwise its right child.
    BST() : a(n + 1) {
        e[0].assign(n + 1, -1);
        e[1].assign(n + 1, -1);
        for (int i = 1; i <= n; ++i) {
            std::cin >> a[i];
            if (a[i] == -1) {
                root = i;
            } else if (i < a[i]) {
                e[0][a[i]] = i;
            } else {
                e[1][a[i]] = i;
            }
        }
    }

    // Lays the subtree rooted at `node` onto the line and adds its cost to `ans`
    // (presumably a rotation count — confirm against the problem statement).
    //  - `dir` is the child direction that continues the current path.
    //  - A node reached through the opposite direction, or the top node when
    //    `add` is true, adds its subtree size to `ans`.
    //  - Widens [min_sorted, max_sorted] to cover every key in the subtree.
    //  - Sets `found` if `r2` is in the subtree.
    // Returns the subtree's node count (0 when node == -1).
    //
    // Uses a work list instead of recursion: on a chain-shaped subtree the
    // recursion depth would equal the subtree size.
    int flatten(int node, bool dir, int r2, bool& found, bool add, int& max_sorted, int& min_sorted) {
        if (node == -1) return 0;

        struct Item {
            int node;
            bool dir;
            bool add;
            int parent;  // index into `items`, -1 for the top node
            int size;    // subtree size, accumulated in the second pass
        };
        std::vector<Item> items;
        items.push_back(Item{node, dir, add, -1, 1});

        // Discovery pass: every child is appended after its parent.
        for (std::size_t i = 0; i < items.size(); ++i) {
            const Item cur = items[i];  // copy: push_back below may reallocate
            if (cur.node == r2) found = true;
            max_sorted = std::max(max_sorted, cur.node);
            min_sorted = std::min(min_sorted, cur.node);
            const int same_side = e[cur.dir][cur.node];
            const int other_side = e[!cur.dir][cur.node];
            if (same_side != -1) items.push_back(Item{same_side, cur.dir, false, static_cast<int>(i), 1});
            if (other_side != -1) items.push_back(Item{other_side, !cur.dir, true, static_cast<int>(i), 1});
        }

        // Reverse discovery order visits children before parents, so each
        // subtree size is complete by the time it is used.
        for (std::size_t i = items.size(); i-- > 0;) {
            const Item& cur = items[i];
            if (cur.add) ans = (ans + cur.size) % mod;
            if (cur.parent >= 0) items[cur.parent].size += cur.size;
        }
        return items[0].size;
    }

    // Adds to `ans` the cost of reshaping this tree's subtree rooted at r1
    // into T's subtree rooted at r2 (both cover the same key interval).
    // `sorted` is the last node already laid out on the line: the keys between
    // r1 and `sorted` are on it.
    //
    // Uses an explicit stack instead of recursion. Tasks are processed in the
    // same order as the recursive form: current node, then the whole left
    // subproblem, then the right one, whose arguments are only computed once
    // the left side is finished. A tall T therefore cannot overflow the call stack.
    void transform(BST& T, int r1, int r2, int sorted) {
        struct Task {
            int r1;
            int r2;
            int sorted;
            // When true: schedule the right subproblem of node r2, with
            // `sorted` holding the max_sorted of r2's own task.
            bool right_of;
        };
        std::vector<Task> stack;
        stack.push_back(Task{r1, r2, sorted, false});

        while (!stack.empty()) {
            const Task task = stack.back();
            stack.pop_back();

            if (task.right_of) {
                const int top = task.r2;
                const int line_max = task.sorted;
                if (top + 1 <= line_max) {
                    stack.push_back(Task{top + 1, T.e[1][top], line_max, false});
                } else {
                    stack.push_back(Task{e[1][top], T.e[1][top], e[1][top], false});
                }
                continue;
            }

            const int s_root = task.r1;
            const int t_root = task.r2;
            const int line_end = task.sorted;
            if (t_root == -1) continue;

            int max_sorted = std::max(s_root, line_end);
            int min_sorted = std::min(s_root, line_end);
            if (min_sorted > t_root || max_sorted < t_root) {
                // t_root is not on the line yet: extend the line toward it along
                // the spine, flattening every subtree hanging off the other side.
                const bool dir = s_root < t_root;
                bool found = false;
                int node = e[dir][line_end];
                while (!found) {
                    assert(node != -1);
                    if (node == t_root) found = true;
                    max_sorted = std::max(max_sorted, node);
                    min_sorted = std::min(min_sorted, node);
                    flatten(e[!dir][node], !dir, t_root, found, true, max_sorted, min_sorted);
                    e[!dir][node] = -1;  // that subtree is now part of the line
                    node = e[dir][node];
                }
            }
            ans = (ans + std::abs(s_root - t_root)) % mod;

            // Push the deferred right subproblem first so it pops after the
            // entire left subproblem, matching the recursive order.
            stack.push_back(Task{0, t_root, max_sorted, true});
            if (t_root - 1 >= min_sorted) {
                stack.push_back(Task{t_root - 1, T.e[0][t_root], min_sorted, false});
            } else {
                stack.push_back(Task{e[0][t_root], T.e[0][t_root], e[0][t_root], false});
            }
        }
    }
};

// Input: n, then n parent indices for tree S, then n for tree T (-1 marks the root).
// Output: the accumulated answer modulo 1e9+7.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> n;
    BST S, T;  // declaration order: S is read from stdin first, then T
    S.transform(T, S.root, T.root, S.root);
    std::cout << ans << '\n';
}