1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <csignal>

using namespace std;

// 64-bit signed integer type used for every accumulated sum in this file.
using LL = long long;

// Modulus applied to all results (1e9 + 7).
constexpr int P = 1000000007;

// One vertex of a binary tree.
//
// Child links are plain non-owning pointers: the struct never allocates or
// frees what they point to. Member order (val, left, right) is kept so that
// positional brace-initialization keeps working.
struct Node {
  int val = 0;            // key of this node; zero until assigned
  Node* left = nullptr;   // left child, or nullptr when there is none
  Node* right = nullptr;  // right child, or nullptr when there is none
};

// Number of nodes in each input tree. Read once in main() and used by
// process_input() to size the node vector and bound its read loop.
int n;

// Forward declarations: unify() and the unfold* helpers call one another, so
// every one of them has to be declared before the first definition.
LL unify(Node* a, Node* b);
LL unfoldlefttill(Node* a, Node* b);
LL unfoldleft(Node *a);
LL unfoldrighttill(Node* a, Node* b);
LL unfoldright(Node *a);

// Walks a and b down their left-child links until both point at nodes with
// the same key, then adds unify() of the two remaining left subtrees.
//
// On each step the side currently holding the larger key moves to its left
// child. That step contributes (parent key - child key - 1) plus
// unfoldright(child) to the total.
//
// Returns the accumulated total modulo P.
//
// Precondition (not checked here): both pointers are non-null and the side
// being advanced always has a left child. Presumably the callers guarantee
// this because both arguments span the same key range; verify before calling
// from new code.
LL unfoldlefttill(Node* a, Node* b) {
  LL res = 0;
  while (a->val != b->val) {
    // Advance whichever side holds the larger key; binding a reference lets
    // one shared body update the caller-visible pointer for that side.
    Node*& cur = (a->val > b->val) ? a : b;
    res = (res + cur->val - cur->left->val - 1) % P;
    cur = cur->left;
    // Reduce after every addition, the same way unfoldleft()/unfoldright() do.
    res = (res + unfoldright(cur)) % P;
  }
  return (res + unify(a->left, b->left)) % P;
}

// Follows left-child links starting at a until a node without a left child
// is reached.
//
// Every hop from a node to its left child adds
// (parent key - child key - 1) and then unfoldright(child) to the total,
// reducing modulo P after each addition. Returns that total; 0 when a has no
// left child. a must be non-null.
LL unfoldleft(Node *a) {
  LL total = 0;
  for (Node* child = a->left; child != nullptr; child = child->left) {
    const LL gap = static_cast<LL>(a->val) - child->val - 1;
    total = (total + gap) % P;
    total = (total + unfoldright(child)) % P;
    a = child;  // the child is the parent of the next hop
  }
  return total;
}

// The next two functions are copies of unfoldlefttill/unfoldleft with left and right swapped.

// Walks a and b down their right-child links until both point at nodes with
// the same key, then adds unify() of the two remaining right subtrees.
//
// On each step the side currently holding the smaller key moves to its right
// child. That step contributes (child key - parent key - 1) plus
// unfoldleft(child) to the total.
//
// Returns the accumulated total modulo P.
//
// Precondition (not checked here): both pointers are non-null and the side
// being advanced always has a right child. Presumably the callers guarantee
// this because both arguments span the same key range; verify before calling
// from new code.
LL unfoldrighttill(Node* a, Node* b) {
  LL res = 0;
  while (a->val != b->val) {
    // Advance whichever side holds the smaller key; binding a reference lets
    // one shared body update the caller-visible pointer for that side.
    Node*& cur = (a->val < b->val) ? a : b;
    res = (res + cur->right->val - cur->val - 1) % P;
    cur = cur->right;
    // Reduce after every addition, the same way unfoldleft()/unfoldright() do.
    res = (res + unfoldleft(cur)) % P;
  }
  return (res + unify(a->right, b->right)) % P;
}

// Follows right-child links starting at a until a node without a right child
// is reached.
//
// Every hop from a node to its right child adds
// (child key - parent key - 1) and then unfoldleft(child) to the total,
// reducing modulo P after each addition. Returns that total; 0 when a has no
// right child. a must be non-null.
LL unfoldright(Node *a) {
  LL total = 0;
  for (Node* child = a->right; child != nullptr; child = child->right) {
    const LL gap = static_cast<LL>(child->val) - a->val - 1;
    total = (total + gap) % P;
    total = (total + unfoldleft(child)) % P;
    a = child;  // the child is the parent of the next hop
  }
  return total;
}

// End of the functions copied with left and right swapped.

// Computes the combined cost (modulo P) for the subtrees rooted at a and b.
//
// - Both null: the cost is 0.
// - Exactly one null: prints "ERROR", raises SIGINT and exits with status 1.
// - Equal keys: the cost is the sum of unify() on the left children and
//   unify() on the right children.
// - Different keys: the cost is unfoldlefttill(a, b), plus the absolute
//   difference of the two keys, plus unfoldrighttill(a, b).
LL unify(Node* a, Node* b) {
  if (a == nullptr || b == nullptr) {
    // With at least one side null, the pointers compare equal only when
    // both are null.
    if (a == b) return 0;
    printf("ERROR\n");
    raise(SIGINT);
    exit(1);
  }

  if (a->val == b->val) {
    const LL left_cost = unify(a->left, b->left);
    const LL right_cost = unify(a->right, b->right);
    return (left_cost + right_cost) % P;
  }

  const LL left_cost = unfoldlefttill(a, b);
  const LL right_cost = unfoldrighttill(a, b);
  const LL key_gap = (a->val > b->val)
                         ? static_cast<LL>(a->val) - b->val
                         : static_cast<LL>(b->val) - a->val;
  return (left_cost + key_gap + right_cost) % P;
}

// Reads one tree of n nodes from stdin into v and returns a pointer to its
// root.
//
// The input is n integers; the i-th (0-based) is the 1-based index of node
// i's parent, and a value <= 0 marks node i as the root. Node i gets key i.
// A child whose index is greater than its parent's is linked as the parent's
// right child, otherwise as its left child.
//
// Returns nullptr when no root was seen (including n <= 0). If several
// entries are marked as root, the last one wins. The returned pointer and all
// child links point into v, so they are invalidated if v is resized later.
//
// Malformed input (missing or non-numeric entry, parent index beyond n, or a
// node listed as its own parent) is reported on stderr and terminates the
// program with status 1. Previously these cases were undefined behavior.
Node* process_input(vector<Node>& v) {
  const int count = n > 0 ? n : 0;  // a negative n must not reach resize()
  v.resize(count);

  Node* root = nullptr;  // stays null when the input names no root
  for (int i = 0; i < count; ++i) {
    v[i].val = i;

    int parent = 0;  // 1-based parent index as written in the input
    if (scanf("%d", &parent) != 1) {
      fprintf(stderr, "process_input: expected %d parent entries, read only %d\n",
              count, i);
      exit(1);
    }

    if (parent <= 0) {
      root = &v[i];
      continue;
    }

    const int p = parent - 1;  // 0-based parent index
    if (p >= count || p == i) {
      fprintf(stderr, "process_input: invalid parent %d for node %d\n",
              parent, i + 1);
      exit(1);
    }

    if (p < i) {
      v[p].right = &v[i];
    } else {
      v[p].left = &v[i];
    }
  }
  return root;
}

// Entry point: reads the node count n, then two trees of n nodes each, and
// prints unify() of their roots followed by a newline.
//
// Returns 1 with a message on stderr when the node count is missing,
// non-numeric or negative; returns 0 otherwise.
int main() {
  if (scanf("%d", &n) != 1 || n < 0) {
    fprintf(stderr, "main: expected a non-negative node count\n");
    return 1;
  }

  // The vectors own the nodes; the root pointers returned below point into
  // them, so both vectors must outlive the unify() call.
  vector<Node> ta, tb;
  Node* a = process_input(ta);
  Node* b = process_input(tb);

  printf("%lld\n", unify(a, b));
  return 0;
}