#include <cassert>
#include <cstdio>
#include <algorithm>
#include <utility>
#include <vector>
using namespace std;
// Capacity of the per-vertex global arrays (adjacency lists and intervals).
#define MAXN 500007

// Build switches: uncomment to trace the computation on stderr and/or to
// enable internal consistency checks.
//#define DEBUGPRINT
//#define DEBUGASSERT

// DEBPRINT: printf-style trace to stderr, compiled out unless DEBUGPRINT.
#if defined(DEBUGPRINT)
#define DEBPRINT(...) fprintf(stderr, __VA_ARGS__)
#else
#define DEBPRINT(...) ((void)0)
#endif

// ASSERT: assert() that is compiled out unless DEBUGASSERT.
#if defined(DEBUGASSERT)
#define ASSERT(x) assert(x)
#else
#define ASSERT(x) ((void)0)
#endif
// Integer type (at least 64 bits) used for accumulated costs.
using slng = long long;

// N: number of vertices; M: number of leaves (vertex ids 0..M-1 are leaves).
int N;
int M;

// Tree adjacency lists, indexed by 0-based vertex id.
std::vector<int> E[MAXN];

// Per-vertex interval [minr[v], maxr[v]] of optimal values; for a leaf both
// bounds hold its fixed input value.
int minr[MAXN];
int maxr[MAXN];
// Kind of interval endpoint: TYPEMIN marks a lower bound, TYPEMAX an upper
// bound. The numeric values make TYPEMIN order before TYPEMAX.
enum typ { TYPEMIN = 0, TYPEMAX = 1 };

// One interval endpoint: its coordinate and which bound it is.
struct pp {
    int r;
    typ type;
};

// Strict weak ordering for endpoints: by coordinate, and at equal
// coordinates a TYPEMIN endpoint sorts before a TYPEMAX one.
bool operator<(const pp& lhs, const pp& rhs) {
    return (lhs.r == rhs.r) ? (lhs.type < rhs.type) : (lhs.r < rhs.r);
}
// Returns the minimum total cost of the edges in the subtree rooted at `n`
// (entered from `parent`). The cost of an edge is the absolute difference of
// the values at its two ends.
//
// Vertices with index < M are leaves whose value is fixed (minr == maxr).
// Every other vertex gets a free value. For each internal vertex u this
// stores in [minr[u], maxr[u]] the interval of optimal values for u given
// its children. That interval is the median interval of the multiset of all
// children's interval endpoints: the c-th and (c+1)-th smallest of the 2c
// endpoints, where c is the number of children. Any point of that interval
// yields the same cost for u's child edges.
//
// Implemented with an explicit stack instead of recursion, so very deep
// trees (e.g. a long chain of internal vertices) cannot overflow the call
// stack.
//
// Side effects: writes minr[]/maxr[] of every internal vertex reached.
// NOTE(review): assumes the graph reachable from n is a tree and that every
// internal vertex has at least one child; a childless internal vertex keeps
// its previous minr/maxr (0 for a never-written global entry).
slng solve(int n, int parent) {
    if (n < M)
        return 0; // leaf: value is fixed, no edges below it

    // Phase 1: pre-order walk over internal vertices, recording (vertex,
    // parent). Leaves are not pushed because they need no processing.
    vector<pair<int, int> > order;
    vector<pair<int, int> > pending;
    pending.push_back(make_pair(n, parent));
    while (!pending.empty()) {
        const pair<int, int> cur = pending.back();
        pending.pop_back();
        order.push_back(cur);
        for (int k : E[cur.first])
            if (k != cur.second && k >= M)
                pending.push_back(make_pair(k, cur.first));
    }

    // Phase 2: reverse pre-order visits every child before its parent, so
    // the children's intervals are final when a vertex is processed.
    slng total = 0;
    vector<int> ends; // reused buffer of children's interval endpoints
    for (size_t idx = order.size(); idx-- > 0;) {
        const int u = order[idx].first;
        const int par = order[idx].second;

        ends.clear();
        for (int k : E[u]) {
            if (k != par) {
                ends.push_back(minr[k]);
                ends.push_back(maxr[k]);
            }
        }
        const size_t c = ends.size() / 2; // number of children
        if (c > 0) {
            // Median interval of the 2c endpoints: [c-th smallest,
            // (c+1)-th smallest]. After nth_element every element past
            // position c-1 is >= ends[c-1], so the (c+1)-th smallest is the
            // minimum of that tail. Selection is linear in c.
            nth_element(ends.begin(), ends.begin() + (c - 1), ends.end());
            minr[u] = ends[c - 1];
            maxr[u] = *min_element(ends.begin() + c, ends.end());
        }
        DEBPRINT("minr,maxr[%d] = [%d, %d]\n", u, minr[u], maxr[u]);

        // Cost of u's child edges with u placed at minr[u]: the distance
        // from minr[u] to each child's optimal interval. Differences are
        // taken in 64 bits so large-magnitude int values cannot overflow.
        slng cost = 0;
        for (int k : E[u]) {
            if (k == par)
                continue;
            if (maxr[k] < minr[u])
                cost += static_cast<slng>(minr[u]) - maxr[k];
            else if (minr[k] > minr[u])
                cost += static_cast<slng>(minr[k]) - minr[u];
        }
#ifdef DEBUGASSERT
        // Every point of the median interval gives the same cost; re-check
        // with u placed at maxr[u].
        slng cost_hi = 0;
        for (int k : E[u]) {
            if (k == par)
                continue;
            if (maxr[k] < maxr[u])
                cost_hi += static_cast<slng>(maxr[u]) - maxr[k];
            else if (minr[k] > maxr[u])
                cost_hi += static_cast<slng>(minr[k]) - maxr[u];
        }
        ASSERT(cost == cost_hi);
#endif
        total += cost;
    }
    return total;
}
int main() {
int i;
slng ret;
scanf("%d%d", &N, &M);
for (i = 0; i < N-1; i++) {
int a,b;
scanf("%d%d", &a, &b);
E[a-1].push_back(b-1);
E[b-1].push_back(a-1);
}
for (i = 0; i < M; i++) {
scanf("%d", &minr[i]);
maxr[i] = minr[i];
}
printf("%lld\n", solve(M, -1));
return 0;
}
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 | #include <cassert> #include <cstdio> #include <vector> #include <algorithm> using namespace std; #define MAXN 500007 //#define DEBUGPRINT //#define DEBUGASSERT #ifdef DEBUGPRINT #define DEBPRINT(...) fprintf(stderr, __VA_ARGS__) #else #define DEBPRINT(...) #endif #ifdef DEBUGASSERT #define ASSERT(x) assert(x) #else #define ASSERT(x) #endif typedef long long slng; int N, M; vector<int> E[MAXN]; int minr[MAXN], maxr[MAXN]; enum typ {TYPEMIN=0, TYPEMAX=1}; struct pp { int r; typ type; }; bool operator<(const pp& lhs, const pp& rhs) { if (lhs.r != rhs.r) return lhs.r < rhs.r; return lhs.type < rhs.type; } slng solve(int n, int parent) { int i; slng ret_child, ret1, ret2; vector<pp> p; int left,in,right; if (n < M) return 0; // leaf ret_child = 0; for (i = 0; i < E[n].size(); i++) { int k = E[n][i]; if (k != parent) { ret_child += solve(k, n); } } for (i = 0; i < E[n].size(); i++) { int k = E[n][i]; if (k != parent) { p.push_back({minr[k], TYPEMIN}); p.push_back({maxr[k], TYPEMAX}); } } sort(p.begin(), p.end()); left = 0; in = 0; right = p.size()/2; for (i = 0; i < p.size(); i++) { DEBPRINT("(%d %d %d)[%d %d]", left, in, right, p[i].r, p[i].type); if (p[i].type == TYPEMIN) { right--; in++; } else if (p[i].type == TYPEMAX) { in--; left++; } if (left == right) { minr[n] = p[i].r; } if (left == right + 1) { maxr[n] = p[i].r; } } DEBPRINT("\nminr,maxr[%d] = [%d, %d]\n", n, minr[n], maxr[n]); ret1 = 0; ret2 = 0; for (i = 0; i < E[n].size(); i++) { int k = E[n][i]; if (k != parent) { if (maxr[k] < minr[n]) { ret1 += minr[n] - maxr[k]; } else if (minr[k] > minr[n]) { ret1 += minr[k] - 
minr[n]; } #ifdef DEBUGASSERT if (maxr[k] < maxr[n]) { ret2 += maxr[n] - maxr[k]; } else if (minr[k] > minr[n]) { ret2 += minr[k] - maxr[n]; } #endif } } ASSERT(ret1 == ret2); return ret_child + ret1; } int main() { int i; slng ret; scanf("%d%d", &N, &M); for (i = 0; i < N-1; i++) { int a,b; scanf("%d%d", &a, &b); E[a-1].push_back(b-1); E[b-1].push_back(a-1); } for (i = 0; i < M; i++) { scanf("%d", &minr[i]); maxr[i] = minr[i]; } printf("%lld\n", solve(M, -1)); return 0; } |
English