#include <cstdio>
#include <cstdlib>
#include <vector>
#include <csignal>
using namespace std;
// 64-bit accumulator type used for every running sum below.
using LL = long long;
// Modulus: every accumulated count in this file is reduced modulo P.
constexpr int P = 1000000007;
// A tree vertex: `val` is its key, `left`/`right` point at child vertices
// and are nullptr when the child is absent.
struct Node {
    int val;
    Node* left{nullptr};
    Node* right{nullptr};
};
int n = 0;  // node count per tree; set in main() and read by process_input()
// Forward declarations: unify and the unfold* helpers call one another recursively.
LL unify(Node* a, Node* b);
LL unfoldleft(Node* a);
LL unfoldright(Node* a);
LL unfoldlefttill(Node* a, Node* b);
LL unfoldrighttill(Node* a, Node* b);
// Walks the left spines of `a` and `b` until both point at nodes with the same key.
// On each step the node holding the larger key moves to its left child, adding
// parent->val - child->val - 1 (the count of integers strictly between the two keys)
// plus unfoldright(child). Once the keys agree, the left subtrees of the two
// matching nodes are reconciled with unify(). Every addition is reduced mod P.
// NOTE(review): assumes the smaller key is reachable by walking left from the larger
// one (as when both trees hold the same key range); otherwise ->left is null here.
LL unfoldlefttill(Node* a, Node* b) {
    LL res = 0;
    while (a->val != b->val) {
        // Reference to whichever pointer must step, so one code path serves both trees.
        Node*& hi = (a->val > b->val) ? a : b;
        res = (res + hi->val - hi->left->val - 1) % P;
        hi = hi->left;
        res = (res + unfoldright(hi)) % P;
    }
    return (res + unify(a->left, b->left)) % P;
}
// Follows the left spine downward from `a` (which must be non-null). For every
// parent -> left-child edge it adds parent->val - child->val - 1 and then
// unfoldright(child), reducing the running total mod P after each addition.
LL unfoldleft(Node *a) {
    LL acc = 0;
    for (Node* child = a->left; child != nullptr; child = a->left) {
        acc = (acc + a->val - child->val - 1) % P;
        a = child;
        acc = (acc + unfoldright(a)) % P;
    }
    return acc;
}
// Right-hand mirrors of unfoldlefttill/unfoldleft: same logic with left/right and the key comparison swapped.
// Right-hand counterpart of unfoldlefttill. While `a` and `b` hold different keys,
// the node with the smaller key moves to its right child. Each step adds
// child->val - parent->val - 1 (the count of integers strictly between the two keys)
// plus unfoldleft(child). Once the keys agree, the right subtrees of the two
// matching nodes are reconciled with unify(). Every addition is reduced mod P.
// NOTE(review): assumes the larger key is reachable by walking right from the
// smaller one (as when both trees hold the same key range); otherwise ->right is null.
LL unfoldrighttill(Node* a, Node* b) {
    LL res = 0;
    while (a->val != b->val) {
        // Reference to whichever pointer must step, so one code path serves both trees.
        Node*& lo = (a->val < b->val) ? a : b;
        res = (res + lo->right->val - lo->val - 1) % P;
        lo = lo->right;
        res = (res + unfoldleft(lo)) % P;
    }
    return (res + unify(a->right, b->right)) % P;
}
// Follows the right spine downward from `a` (which must be non-null). For every
// parent -> right-child edge it adds child->val - parent->val - 1 and then
// unfoldleft(child), reducing the running total mod P after each addition.
LL unfoldright(Node *a) {
    LL acc = 0;
    for (Node* child = a->right; child != nullptr; child = a->right) {
        acc = (acc + child->val - a->val - 1) % P;
        a = child;
        acc = (acc + unfoldleft(a)) % P;
    }
    return acc;
}
// End of the right-hand mirrors.
// Reconciles the subtrees rooted at `a` and `b`, returning the accumulated count mod P.
//  - both empty: 0;
//  - exactly one empty: prints "ERROR", raises SIGINT (e.g. to trap into a debugger),
//    and exits with status 1 if the signal returns;
//  - equal root keys: sum of unify() on the left pair and on the right pair;
//  - different root keys: unfoldlefttill + unfoldrighttill + |key difference|.
// NOTE(review): recursion depth grows with tree height, so very deep (degenerate)
// trees may exhaust the call stack.
LL unify(Node* a, Node* b) {
    if (a == nullptr && b == nullptr) return 0;
    if (a == nullptr || b == nullptr) {
        printf("ERROR\n");
        // If SIGINT terminates the process, buffered stdout is discarded; flush first
        // so the message survives when stdout is redirected to a file or pipe.
        fflush(stdout);
        raise(SIGINT);
        exit(1);
    }
    if (a->val == b->val) {
        return (unify(a->left, b->left) + unify(a->right, b->right)) % P;
    }
    // Both former branches differed only in the sign of the key gap.
    const LL left_part = unfoldlefttill(a, b);
    const LL right_part = unfoldrighttill(a, b);
    const LL gap = (a->val > b->val) ? a->val - b->val : b->val - a->val;
    return (left_part + gap + right_part) % P;
}
// Reads `n` parent indices from stdin (1-based; 0 marks the root) and links the nodes
// stored in `v` so that node i has key i. A node whose parent index is smaller than
// its own becomes that parent's right child, otherwise the parent's left child.
// Returns the root, or nullptr when no root was read (e.g. n <= 0 or truncated input).
// All returned pointers point into `v`, so `v` must not be resized afterwards.
Node* process_input(vector<Node>& v) {
    if (n <= 0) {
        v.clear();
        return nullptr;
    }
    v.assign(n, Node{});   // fresh, unlinked nodes even if `v` was used before
    Node* root = nullptr;  // stays null if no root line is present
    for (int i = 0; i < n; ++i) {
        v[i].val = i;
        int p = 0;
        if (scanf("%d", &p) != 1) break;  // EOF/malformed input: stop, keep links so far
        --p;
        if (p < 0) {
            root = &v[i];
        } else if (p >= n || p == i) {
            // Invalid parent: skip instead of writing out of bounds or self-linking.
            continue;
        } else if (p < i) {
            v[p].right = &v[i];
        } else {
            v[p].left = &v[i];
        }
    }
    return root;
}
// Reads n, then two parent arrays describing trees over keys 0..n-1, and prints
// unify() of the two roots. Exits with status 1 if n is missing or negative.
int main() {
    // A negative n would be converted to a huge size when the node vectors are sized.
    if (scanf("%d", &n) != 1 || n < 0) return 1;
    vector<Node> ta, tb;
    Node* a = process_input(ta);
    Node* b = process_input(tb);
    printf("%lld\n", unify(a, b));
    return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 | #include <cstdio> #include <cstdlib> #include <vector> #include <csignal> using namespace std; typedef long long LL; const int P = 1000000007; struct Node { int val; Node *left = nullptr, *right = nullptr; }; int n; LL unify(Node* a, Node* b); LL unfoldlefttill(Node* a, Node* b); LL unfoldleft(Node *a); LL unfoldrighttill(Node* a, Node* b); LL unfoldright(Node *a); LL unfoldlefttill(Node* a, Node* b) { LL res = 0; while (a->val != b->val) { if (a->val > b->val) { res = (res + a->val - a->left->val - 1) % P; a = a->left; res = (res + unfoldright(a)); } else { res = (res + b->val - b->left->val - 1) % P; b = b->left; res = (res + unfoldright(b)); } } return (res + unify(a->left, b->left)) % P; } LL unfoldleft(Node *a) { LL res = 0; while (a->left != nullptr) { res = (res + a->val - a->left->val - 1) % P; a = a->left; res = (res + unfoldright(a)) % P; } return res; } // Copied and replaced LL unfoldrighttill(Node* a, Node* b) { LL res = 0; while (a->val != b->val) { if (a->val < b->val) { res = (res + a->right->val - a->val - 1) % P; a = a->right; res = (res + unfoldleft(a)); } else { res = (res + b->right->val - b->val - 1) % P; b = b->right; res = (res + unfoldleft(b)); } } return (res + unify(a->right, b->right)) % P; } LL unfoldright(Node *a) { LL res = 0; while (a->right != nullptr) { res = (res + a->right->val - a->val - 1) % P; a = a->right; res = (res + unfoldleft(a)) % P; } return res; } // End of copied and replaced LL unify(Node* a, Node* b) { if (a == nullptr && b == nullptr) return 0; if (a == nullptr || b == nullptr) { printf("ERROR\n"); raise(SIGINT); exit(1); } if (a->val 
== b->val) { return (unify(a->left, b->left) + unify(a->right, b->right)) % P; } else if (a->val > b->val) { return (unfoldlefttill(a, b) + a->val - b->val + unfoldrighttill(a, b)) % P; } else { return (unfoldlefttill(a, b) + b->val - a->val + unfoldrighttill(a, b)) % P; } } Node* process_input(vector<Node>& v) { v.resize(n); Node* root; for (int i=0; i<n; ++i) { v[i].val = i; int p; scanf("%d", &p); --p; if (p < 0) { root = &v[i]; } else if (p < i) { v[p].right = &v[i]; } else { v[p].left = &v[i]; } } return root; } int main() { scanf("%d", &n); vector<Node> ta, tb; Node* a = process_input(ta); Node* b = process_input(tb); printf("%lld\n", unify(a, b)); return 0; } |
English