// Reads two binary trees over the keys 0..n-1 and prints, modulo P, the cost
// computed by unify() for matching the first tree against the second.
// NOTE(review): what this cost represents (the original problem statement) is
// not visible here -- confirm before relying on the interpretation.
//
// Input: n, then n parent indices for the first tree and n for the second.
// Indices are 1-based and 0 marks the root. Node i (key i) becomes the right
// child of a parent with a smaller key and the left child of a parent with a
// larger key.
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <csignal>
#include <utility>

using namespace std;

typedef long long LL;

const int P = 1000000007;  // modulus of the printed answer

// A tree node; `val` is its key, i.e. its 0-based position in the input.
struct Node {
    int val = 0;
    Node *left = nullptr, *right = nullptr;
};

int n;  // number of nodes in each tree (read in main)

LL unify(Node* a, Node* b);
LL unfoldlefttill(Node* a, Node* b);
LL unfoldleft(Node* a);
LL unfoldrighttill(Node* a, Node* b);
LL unfoldright(Node* a);

// Reports malformed input or mismatched tree shapes on stdout and exits with
// status 1.
[[noreturn]] static void fail() {
    printf("ERROR\n");
    exit(1);
}

// Iterative core of unfoldleft/unfoldright.
// Walks the spine below `start` on side `goLeft`. Every node stepped onto
// contributes (key gap to its parent - 1) and is queued so that its spine on
// the opposite side is walked as well. Each node of start's `goLeft` subtree is
// therefore stepped onto exactly once. An explicit stack replaces the original
// mutual recursion, whose depth grew with the tree height.
static LL unfold(Node* start, bool goLeft) {
    LL res = 0;
    vector<pair<Node*, bool> > todo;
    todo.push_back(make_pair(start, goLeft));
    while (!todo.empty()) {
        Node* cur = todo.back().first;
        const bool left = todo.back().second;
        todo.pop_back();
        for (Node* next = left ? cur->left : cur->right; next != nullptr;
             cur = next, next = left ? cur->left : cur->right) {
            const int gap = left ? cur->val - next->val : next->val - cur->val;
            res = (res + gap - 1) % P;
            // next's spine on the other side still needs to be walked.
            todo.push_back(make_pair(next, !left));
        }
    }
    return res;
}

// Unfolding cost of a's left subtree (0 when a has no left child).
LL unfoldleft(Node* a) { return unfold(a, true); }

// Unfolding cost of a's right subtree (0 when a has no right child).
LL unfoldright(Node* a) { return unfold(a, false); }

// Steps a and b down their left spines until they hold the same key. Each step
// moves whichever currently holds the LARGER key. A step adds the keys skipped
// by that edge plus the unfolding cost of the new node's right subtree.
// On return, a and b point at the meeting nodes.
// Assumes a and b root subtrees with the same smallest key; that holds for
// every pair unify() builds from valid input. Under that assumption the node
// holding the larger key always has a left child.
static LL walk_left_spines(Node*& a, Node*& b) {
    LL res = 0;
    while (a->val != b->val) {
        Node*& larger = (a->val > b->val) ? a : b;  // binds to a or b itself
        res = (res + larger->val - larger->left->val - 1) % P;
        larger = larger->left;
        res = (res + unfoldright(larger)) % P;
    }
    return res;
}

// Mirror image of walk_left_spines. Walks right spines, always stepping
// whichever currently holds the SMALLER key, and adds the unfolding cost of
// each new node's left subtree.
// Assumes a and b root subtrees with the same largest key.
static LL walk_right_spines(Node*& a, Node*& b) {
    LL res = 0;
    while (a->val != b->val) {
        Node*& smaller = (a->val < b->val) ? a : b;  // binds to a or b itself
        res = (res + smaller->right->val - smaller->val - 1) % P;
        smaller = smaller->right;
        res = (res + unfoldleft(smaller)) % P;
    }
    return res;
}

// Cost (mod P) of unifying the subtrees rooted at a and b. Both are expected to
// cover the same key range.
// - Equal root keys: no cost; children are unified pairwise.
// - Otherwise: |key difference| plus the left- and right-spine walks. The
//   subtrees beyond each meeting point are then unified.
// - Both empty: 0.
// - Exactly one empty: prints "ERROR" and exits.
// Uses an explicit work list so path-shaped trees cannot overflow the stack.
LL unify(Node* a, Node* b) {
    LL res = 0;
    vector<pair<Node*, Node*> > todo;
    todo.push_back(make_pair(a, b));
    while (!todo.empty()) {
        Node* x = todo.back().first;
        Node* y = todo.back().second;
        todo.pop_back();
        if (x == nullptr && y == nullptr) continue;
        if (x == nullptr || y == nullptr) fail();
        if (x->val == y->val) {
            todo.push_back(make_pair(x->left, y->left));
            todo.push_back(make_pair(x->right, y->right));
            continue;
        }
        res = (res + (x->val > y->val ? x->val - y->val : y->val - x->val)) % P;

        // The two walks start independently from the original roots.
        Node* lx = x;
        Node* ly = y;
        res = (res + walk_left_spines(lx, ly)) % P;
        todo.push_back(make_pair(lx->left, ly->left));

        Node* rx = x;
        Node* ry = y;
        res = (res + walk_right_spines(rx, ry)) % P;
        todo.push_back(make_pair(rx->right, ry->right));
    }
    return res;
}

// Left-spine walk from (a, b) plus the cost of unifying everything to the left
// of the meeting point.
LL unfoldlefttill(Node* a, Node* b) {
    const LL res = walk_left_spines(a, b);
    return (res + unify(a->left, b->left)) % P;
}

// Right-spine walk from (a, b) plus the cost of unifying everything to the
// right of the meeting point.
LL unfoldrighttill(Node* a, Node* b) {
    const LL res = walk_right_spines(a, b);
    return (res + unify(a->right, b->right)) % P;
}

// Reads n parent indices (1-based, 0 = root) into v and returns the root, or
// nullptr if no entry marks one.
// Calls fail() on unreadable input, on an index past n, or on a node that
// names itself as parent (which would create a cycle).
Node* process_input(vector<Node>& v) {
    v.resize(n);
    Node* root = nullptr;
    for (int i = 0; i < n; ++i) {
        v[i].val = i;
        int p = 0;
        if (scanf("%d", &p) != 1) fail();
        --p;
        if (p >= n || p == i) fail();
        if (p < 0) {
            root = &v[i];
        } else if (p < i) {
            v[p].right = &v[i];
        } else {
            v[p].left = &v[i];
        }
    }
    return root;
}

int main() {
    if (scanf("%d", &n) != 1) n = 0;  // no count: two empty trees, answer 0
    if (n < 0) fail();
    vector<Node> ta, tb;
    Node* a = process_input(ta);
    Node* b = process_input(tb);
    printf("%lld\n", unify(a, b));
    return 0;
}
// Removed: a second, verbatim copy of this program.
// It had been captured together with a web page's line-number gutter
// ("1 2 3 ... 123 |") and a trailing "|". Those stray tokens are not valid C++,
// and a second copy would redefine every function, including main.