1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

// Modulus for the answer. Per-subtree costs (pro) and the final result can
// reach O(n^2), which overflows int, so every accumulation below is reduced
// modulo Q through dod().
const int Q = ((int)1e9) + 7;

// x = (x + y) mod Q. Both operands must already lie in [0, Q); their sum is
// then below 2*Q < 2^31, so the intermediate int cannot overflow.
void dod(int& x, int y) {
	x += y;
	if (x >= Q) x -= Q;
}

// A binary search tree over the keys 0..n-1, given as a parent array.
//   o[v]   parent of v (0-based), or -1 for the root
//   ls[v]  left child of v, or -1
//   ps[v]  right child of v, or -1
//   mi[v]  smallest key in the subtree of v
//   ma[v]  largest key in the subtree of v
//   pro[v] (mod Q) sum, over every node u in the subtree of v, of the size of
//          u's child subtree on the side facing u's parent (the left child for
//          the root); presumably the cost of straightening that subtree into a
//          chain -- TODO confirm against the task statement
struct drz {
	explicit drz(int n_) : n(n_), r(-1), o(n), ls(n, -1), ps(n, -1), mi(n), ma(n), pro(n) {}

	// Reads n parent indices from stdin (1-based, -1 marks the root), attaches
	// each node as the left or right child of its parent by comparing keys,
	// then fills mi/ma/pro for the whole tree.
	void wczyt() {
		for (int i = 0; i < n; i++) {
			scanf("%d", &o[i]);
			if (o[i] == -1) {
				r = i;
				continue;
			}
			o[i]--;  // convert the parent index to 0-based
			if (i < o[i]) ls[o[i]] = i;
			else ps[o[i]] = i;
		}
		ustal(r);
	}

	// Computes mi, ma and pro for every node in the subtree of x.
	// Iterative rather than recursive, so chain-shaped trees (depth ~ n) cannot
	// overflow the call stack. Does nothing when x < 0 (no root was read).
	void ustal(int x) {
		if (x < 0) return;

		// Preorder: every node is listed before its children.
		std::vector<int> order;
		order.reserve(n);
		std::vector<int> pending(1, x);
		while (!pending.empty()) {
			const int v = pending.back();
			pending.pop_back();
			order.push_back(v);
			if (ls[v] != -1) pending.push_back(ls[v]);
			if (ps[v] != -1) pending.push_back(ps[v]);
		}

		// Reverse preorder: children are finished before their parent.
		for (int k = (int)order.size() - 1; k >= 0; k--) {
			const int v = order[k];
			mi[v] = ma[v] = v;
			pro[v] = 0;
			if (ls[v] != -1) {
				mi[v] = mi[ls[v]];
				dod(pro[v], pro[ls[v]]);
			}
			if (ps[v] != -1) {
				ma[v] = ma[ps[v]];
				dod(pro[v], pro[ps[v]]);
			}
			// Child on the side facing the parent; the root (o == -1) uses its left child.
			const int s = (o[v] < v) ? ls[v] : ps[v];
			if (s != -1) dod(pro[v], ma[s] - mi[s] + 1);
		}
	}

	int n;
	int r;  // root key, -1 until one has been read
	std::vector<int> o, ls, ps, mi, ma, pro;
};

// Combines the costs of subtree x of A and subtree y of B, reduced modulo Q;
// presumably the rotation distance between them -- TODO confirm against the
// task statement.
//
// For each pair (x, y) it adds |x - y|, then walks the two left spines: while
// their max keys differ, the spine node with the larger max key is stepped
// down, adding the pro and size of its right child subtree. The right spines
// are walked symmetrically: the node with the smaller min key is stepped down,
// adding its left child subtree's pro and size. If one spine runs out, the pro
// of the other spine's current node is added. If the keys agree, that pair is
// processed next.
//
// Pairs are kept on an explicit work list instead of recursing, so deep trees
// cannot overflow the call stack. Only a sum is produced, so the processing
// order is irrelevant. Returns 0 if either root is -1 (e.g. n == 0).
int diff(drz& A, int x, drz& B, int y) {
	if (x < 0 || y < 0) return 0;

	int ret = 0;
	// Flattened stack of (node in A, node in B) pairs: push x, then y.
	std::vector<int> pending;
	pending.push_back(x);
	pending.push_back(y);

	while (!pending.empty()) {
		const int cy = pending.back();
		pending.pop_back();
		const int cx = pending.back();
		pending.pop_back();

		dod(ret, std::max(cx - cy, cy - cx));

		// Left spines, matched by subtree max key.
		int lx = A.ls[cx], ly = B.ls[cy];
		while (lx != -1 && ly != -1 && A.ma[lx] != B.ma[ly]) {
			if (A.ma[lx] < B.ma[ly]) {
				const int p = B.ps[ly];
				if (p != -1) {
					dod(ret, B.pro[p]);
					dod(ret, B.ma[p] - B.mi[p] + 1);
				}
				ly = B.ls[ly];
			} else {
				const int p = A.ps[lx];
				if (p != -1) {
					dod(ret, A.pro[p]);
					dod(ret, A.ma[p] - A.mi[p] + 1);
				}
				lx = A.ls[lx];
			}
		}
		if (lx == -1 && ly != -1) {
			dod(ret, B.pro[ly]);
		} else if (ly == -1 && lx != -1) {
			dod(ret, A.pro[lx]);
		} else if (lx != -1) {
			pending.push_back(lx);
			pending.push_back(ly);
		}

		// Right spines, matched by subtree min key.
		int px = A.ps[cx], py = B.ps[cy];
		while (px != -1 && py != -1 && A.mi[px] != B.mi[py]) {
			if (A.mi[px] > B.mi[py]) {
				const int l = B.ls[py];
				if (l != -1) {
					dod(ret, B.pro[l]);
					dod(ret, B.ma[l] - B.mi[l] + 1);
				}
				py = B.ps[py];
			} else {
				const int l = A.ls[px];
				if (l != -1) {
					dod(ret, A.pro[l]);
					dod(ret, A.ma[l] - A.mi[l] + 1);
				}
				px = A.ps[px];
			}
		}
		if (px == -1 && py != -1) {
			dod(ret, B.pro[py]);
		} else if (py == -1 && px != -1) {
			dod(ret, A.pro[px]);
		} else if (px != -1) {
			pending.push_back(px);
			pending.push_back(py);
		}
	}
	return ret;
}

// Reads n followed by the parent arrays of two trees over the keys 0..n-1,
// and prints the value diff() computes between them.
// Exits with status 1 if n cannot be read or is negative; the original left
// n uninitialized on a failed read.
int main() {
	int n = 0;
	if (scanf("%d", &n) != 1 || n < 0) return 1;
	drz A(n), B(n);
	A.wczyt();
	B.wczyt();
	const int ret = diff(A, A.r, B, B.r);
	printf("%d\n", ret);
	return 0;
}