#include <cassert>
#include <cstdio>
#include <vector>
#include <algorithm>

using namespace std;

#define MAXN 500007

//#define DEBUGPRINT
//#define DEBUGASSERT

#ifdef DEBUGPRINT
#define DEBPRINT(...) fprintf(stderr, __VA_ARGS__)
#else
#define DEBPRINT(...)
#endif

#ifdef DEBUGASSERT
#define ASSERT(x) assert(x)
#else
#define ASSERT(x)
#endif

typedef long long slng;

// Input tree. Nodes are 0-based: 0..M-1 are the leaves whose values are read
// from the input, M..N-1 are the internal nodes whose values are chosen.
int N, M;
vector<int> E[MAXN];

// [minr[v], maxr[v]] is the value interval for node v. For a leaf both ends
// equal its input value; for an internal node it is filled in by solve().
int minr[MAXN], maxr[MAXN];

// Parent of each internal node in the traversal performed by solve().
int par[MAXN];

// Distance from x to the closed interval [lo, hi] (assumes lo <= hi).
// Computed in 64 bits so the difference of two ints cannot overflow.
static slng dist_to_interval(int x, int lo, int hi)
{
    if (x < lo) return (slng)lo - x;
    if (x > hi) return (slng)x - hi;
    return 0;
}

// Returns the total cost charged inside the subtree rooted at n, entered from
// `parent` (pass -1 for the root), and fills minr/maxr for every internal node
// of that subtree. Leaves (index < M) are used as fixed values and never
// expanded further.
//
// For an internal node with k children, [minr, maxr] is set to the k-th and
// (k+1)-th smallest of the 2k endpoints of the children's intervals, i.e. the
// set of x minimizing the sum of distances from x to those intervals. The node
// is then charged the distance from minr[n] to each child's interval.
//
// Implemented with an explicit stack: one recursion frame per node would
// overflow the call stack on a path-shaped tree with ~MAXN nodes.
slng solve(int n, int parent)
{
    if (n < M) return 0;  // leaf: its value is fixed, nothing to choose below it

    // Pre-order walk over the internal nodes of the subtree.
    vector<int> order;
    vector<int> todo(1, n);
    par[n] = parent;
    while (!todo.empty()) {
        int u = todo.back();
        todo.pop_back();
        order.push_back(u);
        for (int v : E[u]) {
            if (v != par[u] && v >= M) {  // leaves are read, not expanded
                par[v] = u;
                todo.push_back(v);
            }
        }
    }

    // Reverse pre-order handles every child before its parent.
    slng total = 0;
    vector<int> ends;
    for (int idx = (int)order.size() - 1; idx >= 0; idx--) {
        int u = order[idx];

        ends.clear();
        for (int v : E[u]) {
            if (v != par[u]) {
                ends.push_back(minr[v]);
                ends.push_back(maxr[v]);
            }
        }
        int k = (int)ends.size() / 2;  // number of children of u
        if (k == 0) continue;  // no children: minr/maxr untouched, nothing charged

        // k-th smallest endpoint, then the smallest of the remaining upper half
        // (= the (k+1)-th smallest).
        nth_element(ends.begin(), ends.begin() + (k - 1), ends.end());
        minr[u] = ends[k - 1];
        maxr[u] = *min_element(ends.begin() + k, ends.end());
        DEBPRINT("\nminr,maxr[%d] = [%d, %d]\n", u, minr[u], maxr[u]);

        slng cost = 0;
#ifdef DEBUGASSERT
        slng costAtMax = 0;  // every value in [minr, maxr] must cost the same
#endif
        for (int v : E[u]) {
            if (v != par[u]) {
                cost += dist_to_interval(minr[u], minr[v], maxr[v]);
#ifdef DEBUGASSERT
                costAtMax += dist_to_interval(maxr[u], minr[v], maxr[v]);
#endif
            }
        }
        ASSERT(cost == costAtMax);
        total += cost;
    }
    return total;
}

// Reads N, M, the N-1 edges (1-based) and the M leaf values, then prints the
// minimal total cost. Returns 1 without output on malformed input.
int main()
{
    if (scanf("%d%d", &N, &M) != 2) return 1;
    if (N < 1 || N > MAXN || M < 0 || M > N) return 1;

    for (int i = 0; i < N - 1; i++) {
        int a, b;
        if (scanf("%d%d", &a, &b) != 2) return 1;
        if (a < 1 || a > N || b < 1 || b > N) return 1;
        E[a - 1].push_back(b - 1);
        E[b - 1].push_back(a - 1);
    }
    for (int i = 0; i < M; i++) {
        if (scanf("%d", &minr[i]) != 1) return 1;
        maxr[i] = minr[i];
    }

    slng answer = 0;
    if (M < N) {
        answer = solve(M, -1);  // node M is the first internal node
    } else {
        // Every node is a leaf, so no value is free: sum each edge once.
        for (int a = 0; a < N; a++)
            for (int b : E[a])
                if (a < b) answer += dist_to_interval(minr[a], minr[b], maxr[b]);
    }
    printf("%lld\n", answer);
    return 0;
}
#include <cassert>
#include <cstdio>
#include <vector>
#include <algorithm>

using namespace std;

#define MAXN 500007

//#define DEBUGPRINT
//#define DEBUGASSERT

#ifdef DEBUGPRINT
#define DEBPRINT(...) fprintf(stderr, __VA_ARGS__)
#else
#define DEBPRINT(...)
#endif

#ifdef DEBUGASSERT
#define ASSERT(x) assert(x)
#else
#define ASSERT(x)
#endif

typedef long long slng;

// Input tree. Nodes are 0-based: 0..M-1 are the leaves whose values are read
// from the input, M..N-1 are the internal nodes whose values are chosen.
int N, M;
vector<int> E[MAXN];

// [minr[v], maxr[v]] is the value interval for node v. For a leaf both ends
// equal its input value; for an internal node it is filled in by solve().
int minr[MAXN], maxr[MAXN];

// Parent of each internal node in the traversal performed by solve().
int par[MAXN];

// Distance from x to the closed interval [lo, hi] (assumes lo <= hi).
// Computed in 64 bits so the difference of two ints cannot overflow.
static slng dist_to_interval(int x, int lo, int hi)
{
    if (x < lo) return (slng)lo - x;
    if (x > hi) return (slng)x - hi;
    return 0;
}

// Returns the total cost charged inside the subtree rooted at n, entered from
// `parent` (pass -1 for the root), and fills minr/maxr for every internal node
// of that subtree. Leaves (index < M) are used as fixed values and never
// expanded further.
//
// For an internal node with k children, [minr, maxr] is set to the k-th and
// (k+1)-th smallest of the 2k endpoints of the children's intervals, i.e. the
// set of x minimizing the sum of distances from x to those intervals. The node
// is then charged the distance from minr[n] to each child's interval.
//
// Implemented with an explicit stack: one recursion frame per node would
// overflow the call stack on a path-shaped tree with ~MAXN nodes.
slng solve(int n, int parent)
{
    if (n < M) return 0;  // leaf: its value is fixed, nothing to choose below it

    // Pre-order walk over the internal nodes of the subtree.
    vector<int> order;
    vector<int> todo(1, n);
    par[n] = parent;
    while (!todo.empty()) {
        int u = todo.back();
        todo.pop_back();
        order.push_back(u);
        for (int v : E[u]) {
            if (v != par[u] && v >= M) {  // leaves are read, not expanded
                par[v] = u;
                todo.push_back(v);
            }
        }
    }

    // Reverse pre-order handles every child before its parent.
    slng total = 0;
    vector<int> ends;
    for (int idx = (int)order.size() - 1; idx >= 0; idx--) {
        int u = order[idx];

        ends.clear();
        for (int v : E[u]) {
            if (v != par[u]) {
                ends.push_back(minr[v]);
                ends.push_back(maxr[v]);
            }
        }
        int k = (int)ends.size() / 2;  // number of children of u
        if (k == 0) continue;  // no children: minr/maxr untouched, nothing charged

        // k-th smallest endpoint, then the smallest of the remaining upper half
        // (= the (k+1)-th smallest).
        nth_element(ends.begin(), ends.begin() + (k - 1), ends.end());
        minr[u] = ends[k - 1];
        maxr[u] = *min_element(ends.begin() + k, ends.end());
        DEBPRINT("\nminr,maxr[%d] = [%d, %d]\n", u, minr[u], maxr[u]);

        slng cost = 0;
#ifdef DEBUGASSERT
        slng costAtMax = 0;  // every value in [minr, maxr] must cost the same
#endif
        for (int v : E[u]) {
            if (v != par[u]) {
                cost += dist_to_interval(minr[u], minr[v], maxr[v]);
#ifdef DEBUGASSERT
                costAtMax += dist_to_interval(maxr[u], minr[v], maxr[v]);
#endif
            }
        }
        ASSERT(cost == costAtMax);
        total += cost;
    }
    return total;
}

// Reads N, M, the N-1 edges (1-based) and the M leaf values, then prints the
// minimal total cost. Returns 1 without output on malformed input.
int main()
{
    if (scanf("%d%d", &N, &M) != 2) return 1;
    if (N < 1 || N > MAXN || M < 0 || M > N) return 1;

    for (int i = 0; i < N - 1; i++) {
        int a, b;
        if (scanf("%d%d", &a, &b) != 2) return 1;
        if (a < 1 || a > N || b < 1 || b > N) return 1;
        E[a - 1].push_back(b - 1);
        E[b - 1].push_back(a - 1);
    }
    for (int i = 0; i < M; i++) {
        if (scanf("%d", &minr[i]) != 1) return 1;
        maxr[i] = minr[i];
    }

    slng answer = 0;
    if (M < N) {
        answer = solve(M, -1);  // node M is the first internal node
    } else {
        // Every node is a leaf, so no value is free: sum each edge once.
        for (int a = 0; a < N; a++)
            for (int b : E[a])
                if (a < b) answer += dist_to_interval(minr[a], minr[b], maxr[b]);
    }
    printf("%lld\n", answer);
    return 0;
}